Support a CPU kernel for Celu (microsoft#6995)
hariharans29 authored Mar 15, 2021
1 parent d0cca35 commit 27ac882
Showing 4 changed files with 35 additions and 3 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/activation/activations.cc
@@ -56,6 +56,7 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(Softplus, 1);
REGISTER_UNARY_ELEMENTWISE_KERNEL(Softsign, 1);
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 6, 12, float);
REGISTER_VERSIONED_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 6, 12, double);
REGISTER_UNARY_ELEMENTWISE_KERNEL(Celu, 12);
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 13, float);
REGISTER_UNARY_ELEMENTWISE_TYPED_KERNEL(Tanh, 13, double);
REGISTER_UNARY_ELEMENTWISE_KERNEL(ThresholdedRelu, 10);
@@ -64,6 +65,7 @@ namespace functors {
template <typename T>
Status ElementWiseRangedTransform<T>::Create(const std::string& type, const NodeAttributes& attributes,
std::unique_ptr<ElementWiseRangedTransform<T>>& out) {
CREATE_ELE_KERNEL(Celu);
CREATE_ELE_KERNEL(Elu);
CREATE_ELE_KERNEL(HardSigmoid);
CREATE_ELE_KERNEL(LeakyRelu);
22 changes: 20 additions & 2 deletions onnxruntime/core/providers/cpu/activation/activations.h
@@ -13,6 +13,23 @@ namespace onnxruntime {

namespace functors {

template <typename T>
struct Celu : public ElementWiseRangedTransform<T> {
ORT_GET_FLOAT_ATTR_AND_RETURN(alpha);

float Cost() const final {
// TODO: Tune the cost
return 1.0f;
}
void operator()(std::ptrdiff_t first, std::ptrdiff_t last) const final {
ptrdiff_t len = last - first;
T* output_ptr = this->output + first;
ConstEigenVectorArrayMap<T> xm(this->input + first, len);
EigenVectorArrayMap<T> ym(output_ptr, len);
ym = xm.cwiseMax(0.0f) + (((T)alpha * ((xm / (T)alpha).exp() - 1)).cwiseMin(0.0f));
}
};

template <typename T>
struct Elu : public ElementWiseRangedTransform<T> {
ORT_GET_FLOAT_ATTR_AND_RETURN(alpha);
@@ -89,9 +106,9 @@ struct Relu : public ElementWiseRangedTransform<T> {
Status Init(const onnxruntime::NodeAttributes&) {
return Status::OK();
}
ElementWiseRangedTransform<T>* Copy() const { // replace it with a macro. why this?
ElementWiseRangedTransform<T>* Copy() const { // replace it with a macro. why this?
using T1 = typename std::remove_pointer<decltype(this)>::type;
using T2 = typename std::remove_const<T1>::type; //redundant?
using T2 = typename std::remove_const<T1>::type; //redundant?
return new T2(*this);
}
float Cost() const final {
@@ -212,6 +229,7 @@ struct Selu : public ElementWiseRangedTransform<T> {

} // namespace functors

DEFINE_ELE_KERNEL(Celu);
DEFINE_ELE_KERNEL(Elu);
DEFINE_ELE_KERNEL(HardSigmoid);
DEFINE_ELE_KERNEL(LeakyRelu);
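
For reference, the Celu functor added above implements the ONNX Celu-12 definition, Celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), with alpha read from the node attribute (the ONNX default is 1.0). The snippet below is a minimal standalone scalar sketch of that formula for checking expected outputs by hand; it is not part of this change, and celu_ref and the sample inputs are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Scalar reference for Celu; matches the Eigen expression in the functor above.
static float celu_ref(float x, float alpha) {
  return std::max(0.0f, x) + std::min(0.0f, alpha * (std::exp(x / alpha) - 1.0f));
}

int main() {
  const float alpha = 1.0f;  // ONNX default for the alpha attribute
  const float inputs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
  for (float x : inputs) {
    std::printf("Celu(%g) = %g\n", x, celu_ref(x, alpha));
  }
  return 0;
}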
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -465,6 +465,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, Dropout);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, Dropout);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Celu);

// opset 13
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Erf);
@@ -1093,7 +1094,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
float, Gemm)>,
float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float,
@@ -1427,6 +1428,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Celu)>,

// opset 13
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Cast)>,
10 changes: 10 additions & 0 deletions onnxruntime/test/providers/cpu/activation/activation_op_test.cc
@@ -61,6 +61,16 @@ TEST_F(ActivationOpTest, Elu) {
{{"alpha", alpha}});
}

TEST_F(ActivationOpTest, Celu) {
float alpha = -0.5f;
TestActivationOp<float>(
"Celu",
input_values,
// TODO: Investigate why gcc 4 fails to compile without the explicit cast
[alpha](float x) { return std::max(0.0f, x) + std::min(0.0f, alpha * (static_cast<float>(exp(x / alpha)) - 1)); },
// Disable on TensorRT as it seems like it doesn't yet support Celu
{{"alpha", alpha}}, false, 12);
}
TEST_F(ActivationOpTest, LeakyRelu) {
float alpha = 0.1f;
TestActivationOp<float>("LeakyRelu",

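The test lambda above is the scalar form of the Eigen expression used in the functor. As a sanity check (a standalone sketch, assuming Eigen headers are available; sample inputs and the tolerance are illustrative), the two forms can be compared directly:

#include <Eigen/Core>
#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
  const float alpha = -0.5f;  // same alpha as the test above
  Eigen::ArrayXf x(5);
  x << -2.0f, -0.5f, 0.0f, 0.5f, 2.0f;

  // Vectorized form, mirroring the functor's operator().
  Eigen::ArrayXf y = x.cwiseMax(0.0f) + (alpha * ((x / alpha).exp() - 1.0f)).cwiseMin(0.0f);

  // Scalar form, mirroring the test lambda.
  for (Eigen::Index i = 0; i < x.size(); ++i) {
    const float expected =
        std::max(0.0f, x[i]) + std::min(0.0f, alpha * (std::exp(x[i] / alpha) - 1.0f));
    assert(std::abs(y[i] - expected) < 1e-3f);
  }
  return 0;
}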