Skip to content

Commit

Permalink
Rename op kernel type control 'supported types' to 'default types'. (m…
Browse files Browse the repository at this point in the history
…icrosoft#6886)

Cleaning up some naming in the op kernel type control infrastructure.
"Supported types" was a bit semantically overloaded. Renamed it to "default types". They are the types that are supported by default.
  • Loading branch information
edgchen1 authored Mar 9, 2021
1 parent 67c6740 commit 73fe1f2
Showing 13 changed files with 187 additions and 182 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/generator/constant_of_shape.cc
Original file line number Diff line number Diff line change
@@ -7,15 +7,15 @@
namespace onnxruntime {

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0,
ConstantOfShapeDefaultOutputTypes);
}

namespace {

using SupportedOutputTypes =
ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(
using OutputTypes =
ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0);

using EnabledOutputTypes =
@@ -70,7 +70,7 @@ ONNX_CPU_OPERATOR_KERNEL(
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T2",
BuildKernelDefConstraintsFromTypeList<SupportedOutputTypes>(),
BuildKernelDefConstraintsFromTypeList<OutputTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledOutputTypes>()),
ConstantOfShape);

16 changes: 10 additions & 6 deletions onnxruntime/core/providers/cpu/math/clip.cc
Original file line number Diff line number Diff line change
@@ -18,18 +18,22 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Clip_6<float>);

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
kCpuExecutionProvider, kOnnxDomain, Clip, 11, Input, 0,
float);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0,
float, double, int8_t, uint8_t, int64_t, uint64_t);
} // namespace op_kernel_type_control

using Clip11Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Clip, 11, Input, 0);
using EnabledClip11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Clip, 11, Input, 0);
using Clip12Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0);
using EnabledClip12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0);
using Clip11Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Clip, 11, Input, 0);
using EnabledClip11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Clip, 11, Input, 0);
using Clip12Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0);
using EnabledClip12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0);

using AllEnabledClipTypes =
utils::TypeSetUnion<
36 changes: 18 additions & 18 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.cc
Original file line number Diff line number Diff line change
@@ -15,48 +15,48 @@ namespace onnxruntime {
// Supported types for operators that have type reduction enabled
namespace op_kernel_type_control {
// Max
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0, float, double);

ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0,
float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0,
float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0,
int64_t);

// Min
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0,
float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0,
float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0,
int64_t);

// Pow
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0, float, double);

// Pow 12 and later has separate Base and Exponent types.
// To reduce templatization we choose to support a subset of types for the base and exponent.
// This gives us 16 permutations.
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12,
Input, 0, int32_t, int64_t, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12,
Input, 1, int32_t, int64_t, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12,
Input, 0, int32_t, int64_t, float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 12,
Input, 1, int32_t, int64_t, float, double);
} // namespace op_kernel_type_control

//
// reduce the supported type lists to what's allowed in this build
//
using Max8Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0);
using Max12Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0);
using Max8Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0);
using Max12Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0);
using EnabledMax8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0);
using EnabledMax12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0);

using Min8Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0);
using Min12Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0);
using Min8Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0);
using Min12Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0);
using EnabledMin8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0);
using EnabledMin12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0);

using Pow7Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0);
using Pow12BaseTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 12, Input, 0);
using Pow12ExpTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 12, Input, 1);
using Pow7Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0);
using Pow12BaseTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 12, Input, 0);
using Pow12ExpTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 12, Input, 1);
using EnabledPow7Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0);
using EnabledPow12BaseTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
Pow, 12, Input, 0);
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/nn/pool.cc
Original file line number Diff line number Diff line change
@@ -14,23 +14,23 @@ using namespace ::onnxruntime::common;
namespace onnxruntime {

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 8, Input, 0,
float,
double);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 12, Input, 0,
double,
float,
int8_t,
uint8_t);
} // namespace op_kernel_type_control

using MaxPool8DataTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(
using MaxPool8DataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 8, Input, 0);
using EnabledMaxPool8DataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 8, Input, 0);
using MaxPool12DataTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(
using MaxPool12DataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 12, Input, 0);
using EnabledMaxPool12DataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 12, Input, 0);
29 changes: 15 additions & 14 deletions onnxruntime/core/providers/cpu/tensor/cast_op.cc
Original file line number Diff line number Diff line change
@@ -30,23 +30,23 @@ namespace onnxruntime {

namespace op_kernel_type_control {
// we're using one set of types for all opsets of Cast
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
int64_t);

ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0,
ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES);
} // namespace op_kernel_type_control

namespace {
using SupportedSrcTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Cast, Input, 0);
using SupportedDstTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Cast, Output, 0);
using SrcTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Cast, Input, 0);
using DstTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Cast, Output, 0);
using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Cast, Input, 0);
using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
@@ -282,8 +282,9 @@ template <typename TSrc>
struct SrcDispatcher {
void operator()(
int32_t to, const OpKernelContext& context, const TensorShape& shape, const Tensor& src, Tensor& dst) {
using DstTypes = boost::mp11::mp_remove_if_q<EnabledDstTypes, boost::mp11::mp_bind_front<std::is_same, TSrc>>;
utils::MLTypeCallDispatcherFromTypeList<DstTypes> dispatcher{to};
using EnabledDstTypesWithoutSrcType =
boost::mp11::mp_remove_if_q<EnabledDstTypes, boost::mp11::mp_bind_front<std::is_same, TSrc>>;
utils::MLTypeCallDispatcherFromTypeList<EnabledDstTypesWithoutSrcType> dispatcher{to};
dispatcher.template InvokeWithLeadingTemplateArgs<Dispatcher, TypeList<TSrc>>(context, shape, src, dst);
}
};
@@ -311,8 +312,8 @@ Status Cast::Compute(OpKernelContext* context) const {
return Status::OK();
}

const auto supported_src_type_constraints = BuildKernelDefConstraintsFromTypeList<SupportedSrcTypes>();
const auto supported_dst_type_constraints = BuildKernelDefConstraintsFromTypeList<SupportedDstTypes>();
const auto src_type_constraints = BuildKernelDefConstraintsFromTypeList<SrcTypes>();
const auto dst_type_constraints = BuildKernelDefConstraintsFromTypeList<DstTypes>();
const auto enabled_src_type_constraints = BuildKernelDefConstraintsFromTypeList<EnabledSrcTypes>();
const auto enabled_dst_type_constraints = BuildKernelDefConstraintsFromTypeList<EnabledDstTypes>();

@@ -323,17 +324,17 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
6,
12,
KernelDefBuilder()
.TypeConstraint("T1", supported_src_type_constraints, enabled_src_type_constraints)
.TypeConstraint("T2", supported_dst_type_constraints, enabled_dst_type_constraints)
.TypeConstraint("T1", src_type_constraints, enabled_src_type_constraints)
.TypeConstraint("T2", dst_type_constraints, enabled_dst_type_constraints)
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);

ONNX_CPU_OPERATOR_KERNEL(
Cast,
13,
KernelDefBuilder()
.TypeConstraint("T1", supported_src_type_constraints, enabled_src_type_constraints)
.TypeConstraint("T2", supported_dst_type_constraints, enabled_dst_type_constraints)
.TypeConstraint("T1", src_type_constraints, enabled_src_type_constraints)
.TypeConstraint("T2", dst_type_constraints, enabled_dst_type_constraints)
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);

14 changes: 7 additions & 7 deletions onnxruntime/core/providers/cpu/tensor/gather.cc
Original file line number Diff line number Diff line change
@@ -11,19 +11,19 @@
namespace onnxruntime {

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Gather, Input, 1, int32_t, int64_t);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Gather, Input, 1, int64_t);
}

namespace {
using SupportedIndexTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Gather, Input, 1);
using IndexTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Gather, Input, 1);
using EnabledIndexTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Gather, Input, 1);

const auto supported_index_type_constraints = BuildKernelDefConstraintsFromTypeList<SupportedIndexTypes>();
const auto index_type_constraints = BuildKernelDefConstraintsFromTypeList<IndexTypes>();
const auto enabled_index_type_constraints = BuildKernelDefConstraintsFromTypeList<EnabledIndexTypes>();
} // namespace

@@ -33,7 +33,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
10,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", supported_index_type_constraints, enabled_index_type_constraints),
.TypeConstraint("Tind", index_type_constraints, enabled_index_type_constraints),
Gather);

ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
@@ -42,15 +42,15 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
12,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", supported_index_type_constraints, enabled_index_type_constraints),
.TypeConstraint("Tind", index_type_constraints, enabled_index_type_constraints),
Gather);

ONNX_CPU_OPERATOR_KERNEL(
Gather,
13,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", supported_index_type_constraints, enabled_index_type_constraints),
.TypeConstraint("Tind", index_type_constraints, enabled_index_type_constraints),
Gather);

Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
16 changes: 8 additions & 8 deletions onnxruntime/core/providers/cpu/tensor/isinf.cc
Original file line number Diff line number Diff line change
@@ -14,18 +14,18 @@ namespace onnxruntime {
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#IsInf

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
float, double);
} // namespace op_kernel_type_control

class IsInf final : public OpKernel {
public:
using SupportedTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
IsInf, Input, 0);
using DataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
IsInf, Input, 0);

using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
IsInf, Input, 0);
using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
IsInf, Input, 0);

explicit IsInf(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;
@@ -40,8 +40,8 @@ ONNX_CPU_OPERATOR_KERNEL(
10,
KernelDefBuilder()
.TypeConstraint("T1",
BuildKernelDefConstraintsFromTypeList<IsInf::SupportedTypes>(),
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledTypes>())
BuildKernelDefConstraintsFromTypeList<IsInf::DataTypes>(),
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsInf);

@@ -92,7 +92,7 @@ Status IsInf::Compute(OpKernelContext* context) const {

using namespace isinf_internal;

utils::MLTypeCallDispatcherFromTypeList<EnabledTypes> dispatcher{X.GetElementType()};
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes> dispatcher{X.GetElementType()};
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);

return Status::OK();
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/tensor/pad.cc
Original file line number Diff line number Diff line change
@@ -39,12 +39,12 @@ ONNX_OPERATOR_KERNEL_EX(Pad,
#endif

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 2, Input, 0,
float,
double);

ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0,
float,
double,
@@ -56,11 +56,11 @@ ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
uint8_t);
} // namespace op_kernel_type_control

using Pad2Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(
using Pad2Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 2, Input, 0);
using EnabledPad2Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 2, Input, 0);
using Pad11Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(
using Pad11Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0);
using EnabledPad11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0);
Loading

0 comments on commit 73fe1f2

Please sign in to comment.