Enable type reduction for Clip, MaxPool, and Pad CPU kernels. (micros…
edgchen1 authored Mar 8, 2021
1 parent b6c4a7a commit 15d81fb
Showing 4 changed files with 172 additions and 61 deletions.
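
Context for the diff below: each kernel's hard-coded type constraints are replaced with the op kernel type control machinery, so that builds can compile out support for element types they do not need. A minimal sketch of the underlying pattern, using boost::mp11 directly; the names SupportedTypes, EnabledTypes, SKETCH_REDUCED_BUILD, and IsTypeEnabled are illustrative stand-ins, not the actual ONNX Runtime macros.

```cpp
// Minimal sketch of the idea behind op kernel type reduction, using boost::mp11.
#include <cstdint>
#include <boost/mp11.hpp>

namespace sketch {

// Full set of types the kernel implementation can handle.
using SupportedTypes = boost::mp11::mp_list<float, double, std::int8_t, std::uint8_t,
                                            std::int64_t, std::uint64_t>;

// A size-reduced build could narrow this list via configuration; here we just pick
// float to show the effect.
#if defined(SKETCH_REDUCED_BUILD)
using EnabledTypes = boost::mp11::mp_list<float>;
#else
using EnabledTypes = SupportedTypes;
#endif

// Compile-time query used to skip code paths for disabled types.
template <typename T>
constexpr bool IsTypeEnabled() {
  return boost::mp11::mp_contains<EnabledTypes, T>::value;
}

static_assert(IsTypeEnabled<float>(), "float is always enabled in this sketch");

}  // namespace sketch
```
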
73 changes: 52 additions & 21 deletions onnxruntime/core/providers/cpu/math/clip.cc
@@ -2,7 +2,10 @@
// Licensed under the MIT License.

#include "core/providers/cpu/math/clip.h"

#include "core/framework/data_types_internal.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"
#include "core/util/math_cpuonly.h"

namespace onnxruntime {
@@ -14,27 +17,56 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Clip_6<float>);

#define REG_KERNEL_VERSIONED_NONTEMPL(OP_TYPE, START_VER, END_VER, KERNEL_CLASS, ...) \
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(OP_TYPE, \
START_VER, \
END_VER, \
KernelDefBuilder() \
.MayInplace(0, 0) \
.TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()), \
KERNEL_CLASS);

#define REG_KERNEL_NONTEMPL(OP_TYPE, VERSION, KERNEL_CLASS, ...) \
ONNX_CPU_OPERATOR_KERNEL( \
OP_TYPE, \
VERSION, \
KernelDefBuilder() \
.MayInplace(0, 0) \
.TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()), \
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Clip, 11, Input, 0,
float);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0,
float, double, int8_t, uint8_t, int64_t, uint64_t);
} // namespace op_kernel_type_control

using Clip11Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Clip, 11, Input, 0);
using EnabledClip11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Clip, 11, Input, 0);
using Clip12Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0);
using EnabledClip12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0);

using AllEnabledClipTypes =
utils::TypeSetUnion<
EnabledClip11Types,
EnabledClip12Types>;

#define REG_KERNEL_VERSIONED_NONTEMPL( \
OP_TYPE, START_VER, END_VER, KERNEL_CLASS, DEFAULT_TYPE_LIST, ENABLED_TYPE_LIST) \
ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \
OP_TYPE, \
START_VER, \
END_VER, \
KernelDefBuilder() \
.MayInplace(0, 0) \
.TypeConstraint("T", \
BuildKernelDefConstraintsFromTypeList<DEFAULT_TYPE_LIST>(), \
BuildKernelDefConstraintsFromTypeList<ENABLED_TYPE_LIST>()), \
KERNEL_CLASS);

REG_KERNEL_VERSIONED_NONTEMPL(Clip, 11, 11, Clip, float);
REG_KERNEL_VERSIONED_NONTEMPL(Clip, 12, 12, Clip, float, double, int8_t, uint8_t, int64_t, uint64_t);
REG_KERNEL_NONTEMPL(Clip, 13, Clip, float, double, int8_t, uint8_t, int64_t, uint64_t);
#define REG_KERNEL_NONTEMPL( \
OP_TYPE, VERSION, KERNEL_CLASS, DEFAULT_TYPE_LIST, ENABLED_TYPE_LIST) \
ONNX_CPU_OPERATOR_KERNEL( \
OP_TYPE, \
VERSION, \
KernelDefBuilder() \
.MayInplace(0, 0) \
.TypeConstraint("T", \
BuildKernelDefConstraintsFromTypeList<DEFAULT_TYPE_LIST>(), \
BuildKernelDefConstraintsFromTypeList<ENABLED_TYPE_LIST>()), \
KERNEL_CLASS);

REG_KERNEL_VERSIONED_NONTEMPL(Clip, 11, 11, Clip, Clip11Types, EnabledClip11Types);
REG_KERNEL_VERSIONED_NONTEMPL(Clip, 12, 12, Clip, Clip12Types, EnabledClip12Types);
REG_KERNEL_NONTEMPL(Clip, 13, Clip, Clip12Types, EnabledClip12Types);

#undef REG_KERNEL_VERSIONED_NONTEMPL
#undef REG_KERNEL_NONTEMPL

template <typename T>
Status Clip_6<T>::Compute(OpKernelContext* ctx) const {
@@ -74,8 +106,7 @@ Status Clip::Compute(OpKernelContext* ctx) const {
const auto* max = ctx->Input<Tensor>(2);
Tensor* Y = ctx->Output(0, X->Shape());

utils::MLTypeCallDispatcher<float, double, int8_t, uint8_t, int64_t, uint64_t>
t_disp(X->GetElementType());
utils::MLTypeCallDispatcherFromTypeList<AllEnabledClipTypes> t_disp(X->GetElementType());

t_disp.Invoke<ComputeImpl>(X, min, max, Y);
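
The Compute path above now builds its dispatcher from AllEnabledClipTypes (the TypeSetUnion of the per-opset enabled lists) instead of a fixed template argument list, so ComputeImpl-style per-type code is only instantiated for enabled types. Below is a simplified stand-in for utils::MLTypeCallDispatcherFromTypeList, assuming a hypothetical TypeTag/TypeTagOf pair in place of ONNX Runtime's element type ids.

```cpp
// Simplified dispatcher over an enabled-type list; not the real ONNX Runtime class.
#include <cstdint>
#include <stdexcept>
#include <boost/mp11.hpp>

namespace sketch {

enum class TypeTag { kFloat, kDouble, kInt8, kUint8 };

template <typename T> constexpr TypeTag TypeTagOf();
template <> constexpr TypeTag TypeTagOf<float>() { return TypeTag::kFloat; }
template <> constexpr TypeTag TypeTagOf<double>() { return TypeTag::kDouble; }
template <> constexpr TypeTag TypeTagOf<std::int8_t>() { return TypeTag::kInt8; }
template <> constexpr TypeTag TypeTagOf<std::uint8_t>() { return TypeTag::kUint8; }

// Invokes Fn<T>::Invoke(args...) for the T in EnabledTypes whose tag matches the
// runtime tag; types outside EnabledTypes are never instantiated, which is what
// shrinks the binary.
template <typename EnabledTypes, template <typename> class Fn, typename... Args>
void Dispatch(TypeTag tag, Args&&... args) {
  bool handled = false;
  boost::mp11::mp_for_each<EnabledTypes>([&](auto sample) {
    using T = decltype(sample);
    if (!handled && tag == TypeTagOf<T>()) {
      Fn<T>::Invoke(args...);
      handled = true;
    }
  });
  if (!handled) throw std::runtime_error("element type not enabled in this build");
}

}  // namespace sketch
```

Used as Dispatch<EnabledTypes, SomeFunctor>(tag, ...), this mirrors how t_disp.Invoke<ComputeImpl>(...) only has to cover the types left enabled in the build.
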

57 changes: 44 additions & 13 deletions onnxruntime/core/providers/cpu/nn/pool.cc
@@ -2,14 +2,44 @@
// Licensed under the MIT License.

#include "core/providers/cpu/nn/pool.h"

#include "core/framework/data_types_internal.h"
#include "core/platform/threadpool.h"
#include "pool_functors.h"
#include "core/providers/cpu/nn/pool_functors.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"

using namespace ::onnxruntime::common;

namespace onnxruntime {

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 8, Input, 0,
float,
double);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 12, Input, 0,
double,
float,
int8_t,
uint8_t);
} // namespace op_kernel_type_control

using MaxPool8DataTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 8, Input, 0);
using EnabledMaxPool8DataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 8, Input, 0);
using MaxPool12DataTypes = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 12, Input, 0);
using EnabledMaxPool12DataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, MaxPool, 12, Input, 0);

using AllEnabledMaxPoolDataTypes =
utils::TypeSetUnion<
EnabledMaxPool8DataTypes,
EnabledMaxPool12DataTypes>;

template <typename T>
inline static void RunLoop(concurrency::ThreadPool* tp, std::ptrdiff_t total_channels, T&& task) {
concurrency::ThreadPool::TryParallelFor(tp, total_channels, task.Cost(), task);
@@ -130,9 +160,8 @@ Status Pool<float, AveragePool>::Compute(OpKernelContext* context) const {
pool_attrs_.count_include_pad ? MlasAveragePoolingIncludePad : MlasAveragePoolingExcludePad);
}


Status MaxPoolV8::Compute(OpKernelContext* context) const {
utils::MLTypeCallDispatcher<float, double, int8_t, uint8_t>
utils::MLTypeCallDispatcherFromTypeList<AllEnabledMaxPoolDataTypes>
t_disp(context->Input<Tensor>(0)->GetElementType());
return t_disp.InvokeRet<Status, ComputeHelper>(this, context);
}
@@ -239,19 +268,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(MaxPool, 1, 7,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Pool<float, MaxPool<1 /*VERSION*/>>);

ONNX_CPU_OPERATOR_VERSIONED_KERNEL(MaxPool, 8, 11,
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>()})
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
MaxPoolV8);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(MaxPool, 8, 11,
KernelDefBuilder()
.TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<MaxPool8DataTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledMaxPool8DataTypes>())
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
MaxPoolV8);

ONNX_CPU_OPERATOR_KERNEL(MaxPool, 12,
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<uint8_t>()})
.TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<MaxPool12DataTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledMaxPool12DataTypes>())
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
MaxPoolV8);
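
The MaxPool registrations above now pass two constraint lists to TypeConstraint, both produced by BuildKernelDefConstraintsFromTypeList: presumably the full supported list followed by the build-enabled subset. A rough sketch of what such a helper could look like, expanding a compile-time type list into a runtime list of type descriptors; TypeNameOf and the string descriptors are hypothetical stand-ins for DataTypeImpl::GetTensorType<T>().

```cpp
// Sketch: expand a type list into runtime constraint descriptors for registration.
#include <cstdint>
#include <string>
#include <vector>
#include <boost/mp11.hpp>

namespace sketch {

template <typename T> const char* TypeNameOf();
template <> const char* TypeNameOf<float>() { return "tensor(float)"; }
template <> const char* TypeNameOf<double>() { return "tensor(double)"; }
template <> const char* TypeNameOf<std::int8_t>() { return "tensor(int8)"; }
template <> const char* TypeNameOf<std::uint8_t>() { return "tensor(uint8)"; }

// Visits each type in the list once and collects its descriptor.
template <typename TypeList>
std::vector<std::string> BuildConstraintsFromTypeList() {
  std::vector<std::string> constraints;
  boost::mp11::mp_for_each<TypeList>([&](auto sample) {
    constraints.push_back(TypeNameOf<decltype(sample)>());
  });
  return constraints;
}

}  // namespace sketch
```

Keeping both lists lets the registration continue to describe the operator's full ONNX type constraint while only the enabled subset is actually compiled in.
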

97 changes: 70 additions & 27 deletions onnxruntime/core/providers/cpu/tensor/pad.cc
@@ -1,15 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cpu/tensor/pad.h"

#include "core/common/optional.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"
#include "core/util/math.h"

// there's no way to use a raw pointer as the copy destination with std::copy_n
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
#ifdef _MSC_VER
#pragma warning(disable : 4996)
#endif
#include "core/util/math.h"
#include "core/providers/cpu/tensor/pad.h"
#include "core/providers/cpu/tensor/utils.h"

namespace onnxruntime {

@@ -33,12 +38,46 @@ ONNX_OPERATOR_KERNEL_EX(Pad,

#endif

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 2, Input, 0,
float,
double);

ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0,
float,
double,
int32_t,
int64_t,
uint32_t,
uint64_t,
int8_t,
uint8_t);
} // namespace op_kernel_type_control

using Pad2Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 2, Input, 0);
using EnabledPad2Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 2, Input, 0);
using Pad11Types = ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0);
using EnabledPad11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0);

using AllEnabledPadTypes =
utils::TypeSetUnion<
EnabledPad2Types,
EnabledPad11Types>;

// only float type is supported for opset-10
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Pad,
2, 10,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>()}),
KernelDefBuilder().TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<Pad2Types>(),
BuildKernelDefConstraintsFromTypeList<EnabledPad2Types>()),
Pad);

// The interface for the 'Pad' op was changed in opset-11
@@ -48,27 +87,19 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Pad,
11, 12,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<uint8_t>()}),
KernelDefBuilder().TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<Pad11Types>(),
BuildKernelDefConstraintsFromTypeList<EnabledPad11Types>()),
Pad);

ONNX_CPU_OPERATOR_KERNEL(
Pad,
13,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<uint8_t>()}),
KernelDefBuilder().TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<Pad11Types>(),
BuildKernelDefConstraintsFromTypeList<EnabledPad11Types>()),
Pad);

// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values)
@@ -419,7 +450,6 @@ Status Pad::Compute(OpKernelContext* ctx) const {
const std::vector<int64_t>* pads_to_use;
const std::vector<int64_t>* slices_to_use;
PadValue value;
Status status;

// kOnnxDomain Pad opset >= 11 (Or) kMsDomain opset == 1
if (is_dynamic_) {
@@ -483,19 +513,32 @@ Status Pad::Compute(OpKernelContext* ctx) const {
pads_to_use = &pads_;
slices_to_use = &slices_;
}

optional<Status> pad_status{};
switch (element_size) {
case sizeof(uint32_t):
status = PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint32_t>()) {
pad_status = PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
}
break;
case sizeof(uint64_t):
status = PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint64_t>()) {
pad_status = PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
}
break;
case sizeof(uint8_t):
status = PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint8_t>()) {
pad_status = PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
}
break;
default:
ORT_THROW("Unsupported input data type of ", data_type);
break;
}

if (!pad_status) {
pad_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
}
return status;

return *pad_status;
}
}; // namespace onnxruntime
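
Pad's Compute dispatches PadImpl on element byte size rather than on the exact type, so the new guards above only ask whether some enabled type shares that size, and fall back to a failure Status otherwise. A standalone version of that check, mirroring HasTypeWithSameSize from op_kernel_type_control_utils.h; EnabledTypes here is any boost::mp11 type list.

```cpp
// Size-based enablement check: a PadImpl<T> instantiation is needed only if some
// enabled type has sizeof(T).
#include <cstdint>
#include <boost/mp11.hpp>

namespace sketch {

template <typename T>
using SizeOfT = boost::mp11::mp_size_t<sizeof(T)>;

template <typename EnabledTypes, typename T>
constexpr bool HasTypeWithSameSize() {
  using EnabledSizes =
      boost::mp11::mp_unique<boost::mp11::mp_transform<SizeOfT, EnabledTypes>>;
  return boost::mp11::mp_set_contains<EnabledSizes, SizeOfT<T>>::value;
}

// With only float enabled, the 4-byte branch survives but the 8-byte and 1-byte
// branches compile away.
using OnlyFloat = boost::mp11::mp_list<float>;
static_assert(HasTypeWithSameSize<OnlyFloat, std::uint32_t>(), "same size as float");
static_assert(!HasTypeWithSameSize<OnlyFloat, std::uint64_t>(), "no 8-byte type enabled");

}  // namespace sketch
```
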
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/op_kernel_type_control_utils.h
@@ -36,6 +36,12 @@ constexpr bool HasTypeWithSameSize() {
return boost::mp11::mp_set_contains<EnabledTypeSizes, SizeOfT<T>>::value;
}

/**
* The union of the given type sets.
*/
template <typename... TypeSets>
using TypeSetUnion = boost::mp11::mp_set_union<TypeSets...>;

} // namespace utils
} // namespace onnxruntime
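
The new TypeSetUnion alias is a thin wrapper over boost::mp11::mp_set_union. A small usage sketch with hypothetical per-opset lists, showing how the merged set (like AllEnabledClipTypes or AllEnabledMaxPoolDataTypes above) deduplicates types shared across opset versions:

```cpp
// Usage sketch for TypeSetUnion; Enabled11/Enabled12 are illustrative lists.
#include <cstdint>
#include <boost/mp11.hpp>

namespace sketch {

template <typename... TypeSets>
using TypeSetUnion = boost::mp11::mp_set_union<TypeSets...>;

using Enabled11 = boost::mp11::mp_list<float>;
using Enabled12 = boost::mp11::mp_list<float, std::int8_t>;
using AllEnabled = TypeSetUnion<Enabled11, Enabled12>;

static_assert(boost::mp11::mp_size<AllEnabled>::value == 2,
              "duplicates are removed when the sets are merged");

}  // namespace sketch
```
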
