
put all gemmlowp common code in one place (microsoft#1590)
* put all gemmlowp common code in one place

* fix gpu build failures

* minor update
askhade authored Aug 11, 2019
1 parent 59c9d83 commit 7be40b2
Showing 12 changed files with 248 additions and 278 deletions.
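The gist of the change: the gemmlowp wrapper code that each quantized kernel previously carried is consolidated into a shared core/util/gemmlowp_common unit. Judging only from the call sites in the diffs below, the shared header presumably declares something along these lines (the exact signatures and return types are an assumption, not taken from the commit):

// Hypothetical sketch of core/util/gemmlowp_common.h, inferred from the call
// sites in matmul_integer.cc, quantize_linear_matmul.cc and conv_integer.cc.
#pragma once

#include <cstdint>

#include "core/common/common.h"

namespace onnxruntime {

// uint8 x uint8 -> int32 GEMM (used by MatMulInteger and ConvInteger).
Status GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data,
                                int32_t* result_data, int lhs_offset, int rhs_offset,
                                int m, int n, int k);

// uint8 x uint8 -> uint8 GEMM with requantization (used by QLinearMatMul).
Status GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data,
                               uint8_t* result_data, int lhs_offset, int rhs_offset,
                               int result_offset, int m, int n, int k,
                               int32_t int_multiplier, int32_t right_shift);

}  // namespace onnxruntime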
11 changes: 0 additions & 11 deletions cmake/onnxruntime_providers.cmake
@@ -78,17 +78,6 @@ if(HAS_DEPRECATED_COPY)
set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/where_op.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy)
endif()

if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT MSVC)
# For x86 platforms it is important to pass this flag to compiler. Without this gemmlowp will use slow reference code.
# These optimizations are not enabled on MSVC so excluding it.
message("enabling optimizations for gemmlowp")
set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/math/matmul_integer.cc" PROPERTIES COMPILE_FLAGS "-msse4.1")
set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/math/quantize_linear_matmul.cc" PROPERTIES COMPILE_FLAGS "-msse4.1")
set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/nn/qlinearconv.cc" PROPERTIES COMPILE_FLAGS "-msse4.1")
set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/nn/conv_integer.cc" PROPERTIES COMPILE_FLAGS "-msse4.1")
endif()

set(gemmlowp_src ${PROJECT_SOURCE_DIR}/external/gemmlowp)
set(re2_src ${ONNXRUNTIME_ROOT}/../cmake/external/re2)
target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${gemmlowp_src} ${re2_src})
add_dependencies(onnxruntime_providers gsl onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
11 changes: 10 additions & 1 deletion cmake/onnxruntime_util.cmake
@@ -8,8 +8,17 @@ file(GLOB_RECURSE onnxruntime_util_srcs CONFIGURE_DEPENDS

source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_util_srcs})

if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT MSVC)
# For x86 platforms it is important to pass this flag to compiler. Without this gemmlowp will use slow reference code.
# These optimizations are not enabled on MSVC so excluding it.
message("enabling optimizations for gemmlowp")
set_source_files_properties("${ONNXRUNTIME_ROOT}/core/util/gemmlowp_common.cc" PROPERTIES COMPILE_FLAGS "-msse4.1")
endif()

set(gemmlowp_src ${PROJECT_SOURCE_DIR}/external/gemmlowp)

add_library(onnxruntime_util ${onnxruntime_util_srcs})
target_include_directories(onnxruntime_util PRIVATE ${ONNXRUNTIME_ROOT} ${MKLML_INCLUDE_DIR} PUBLIC ${eigen_INCLUDE_DIRS})
target_include_directories(onnxruntime_util PRIVATE ${ONNXRUNTIME_ROOT} ${MKLML_INCLUDE_DIR} ${gemmlowp_src} PUBLIC ${eigen_INCLUDE_DIRS})
onnxruntime_add_include_to_target(onnxruntime_util onnxruntime_common onnxruntime_framework gsl onnx onnx_proto protobuf::libprotobuf)
if(UNIX)
target_compile_options(onnxruntime_util PUBLIC "-Wno-error=comment")
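As the comment in the hunk above notes, gemmlowp only uses its optimized x86 kernels when the translation unit is compiled with SSE4.1 enabled; otherwise it silently falls back to portable reference code. A minimal, self-contained way to see whether the flag reached a given translation unit (this is just the standard compiler macro, not ONNX Runtime code):

// Sketch: checking at compile time whether -msse4.1 is in effect.
// gemmlowp's platform detection keys off the same compiler macros (an
// assumption from memory; see gemmlowp's internal platform-detection header).
#include <cstdio>

int main() {
#if defined(__SSE4_1__)
  std::puts("SSE4.1 enabled: gemmlowp can use its optimized x86 kernels");
#else
  std::puts("SSE4.1 disabled: gemmlowp will fall back to slow reference code");
#endif
  return 0;
}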
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/common.h
@@ -4,6 +4,7 @@
#pragma once

#include "core/common/common.h"
#include "core/framework/tensor.h"

namespace onnxruntime {

@@ -20,4 +21,16 @@ inline int64_t HandleNegativeAxis(int64_t axis, int64_t tensor_rank) {
return axis = axis < 0 ? axis + tensor_rank : axis;
}

/**
Returns true if given tensor is a scalar or 1D tensor of size 1
**/
inline bool IsScalarOr1ElementVector(const Tensor* input) {
if (input->Shape().NumDimensions() == 0 ||
(input->Shape().NumDimensions() == 1 && input->Shape().GetDims().size() == 1)) {
return true;
} else {
return false;
}
}

} // namespace onnxruntime
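The helper above is what the quantized kernels below now use to reject per-channel zero points and scales. The typical call pattern, extracted for reference (the wrapper function and the input index here are illustrative only, not part of the commit):

// Sketch: reading an optional zero-point input once it has been validated as
// a scalar (or a 1-element 1D tensor). Mirrors the call sites in this commit.
#include "core/framework/op_kernel.h"
#include "core/providers/common.h"

namespace onnxruntime {

static int32_t ReadScalarZeroPoint(OpKernelContext* ctx, int input_index) {
  const Tensor* zero_point = ctx->Input<Tensor>(input_index);
  ORT_ENFORCE(IsScalarOr1ElementVector(zero_point),
              "zero point must be a scalar or 1D tensor of size 1");
  return static_cast<int32_t>(*zero_point->Data<uint8_t>());
}

}  // namespace onnxruntime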
54 changes: 15 additions & 39 deletions onnxruntime/core/providers/cpu/math/matmul_integer.cc
@@ -1,14 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef _MSC_VER
#pragma warning(disable : 4244)
#pragma warning(disable : 4267)
#endif

#include "core/providers/cpu/math/matmul_integer.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/util/gemmlowp_common_wrapper.h"
#include "core/util/gemmlowp_common.h"
#include "core/providers/common.h"

namespace onnxruntime {

@@ -24,25 +20,7 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
MatMulInteger<uint8_t, uint8_t, int32_t>);

Status GemmlowpMultiply(const uint8_t* lhs_data, const uint8_t* rhs_data,
int32_t* result_data, const int lhs_offset, const int rhs_offset,
int m, int n, int k) {
const std::tuple<> empty_pipeline = {};
// TODO exp ColMajor order for rhs and result. That may be faster
const auto matOrder = gemmlowp::MapOrder::RowMajor;
gemmlowp::MatrixMap<const std::uint8_t, matOrder> lhs(lhs_data, m, k);
gemmlowp::MatrixMap<const std::uint8_t, matOrder> rhs(rhs_data, k, n);
gemmlowp::MatrixMap<std::int32_t, matOrder> result(result_data, m, n);

gemmlowp::GemmContext gemm_context;
gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
gemmlowp::DefaultL8R8BitDepthParams>(
&gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, empty_pipeline);

return Status::OK();
}

template<>
template <>
Status MatMulInteger<uint8_t, uint8_t, int32_t>::Compute(OpKernelContext* ctx) const {
auto a = ctx->Input<Tensor>(0);
auto b = ctx->Input<Tensor>(1);
@@ -57,28 +35,26 @@ Status MatMulInteger<uint8_t, uint8_t, int32_t>::Compute(OpKernelContext* ctx) const {
int32_t b_offset = 0;
if (has_a_zero_point_) {
auto a_zero_point = ctx->Input<Tensor>(2);
ORT_ENFORCE(a_zero_point->Shape().NumDimensions() == 0 ||
(a_zero_point->Shape().NumDimensions() == 1 && a_zero_point->Shape().GetDims().size() == 1),
"Currently only scalar zero_point is supported. TODO: add per channel zero point support.");
ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
"MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
a_offset = static_cast<int32_t>(*a_zero_point->template Data<uint8_t>());
}
if (has_b_zero_point_) {
auto b_zero_point = ctx->Input<Tensor>(3);
ORT_ENFORCE(b_zero_point->Shape().NumDimensions() == 0 ||
(b_zero_point->Shape().NumDimensions() == 1 && b_zero_point->Shape().GetDims().size() == 1),
"Currently only scalar zero_point is supported. TODO: add per channel zero point support.");
ORT_ENFORCE(IsScalarOr1ElementVector(b_zero_point),
"MatmulInteger : input2 zero point must be a scalar or 1D tensor of size 1");
b_offset = static_cast<int32_t>(*b_zero_point->template Data<uint8_t>());
}

for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
GemmlowpMultiply(a->template Data<uint8_t>() + helper.LeftOffsets()[i],
b->template Data<uint8_t>() + helper.RightOffsets()[i],
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
a_offset,
b_offset,
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()));
GemmlowpMultiplyu8u8_s32(a->template Data<uint8_t>() + helper.LeftOffsets()[i],
b->template Data<uint8_t>() + helper.RightOffsets()[i],
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
a_offset,
b_offset,
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()));
}

return Status::OK();
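For reference, the shared GemmlowpMultiplyu8u8_s32 presumably looks much like the local GemmlowpMultiply deleted above, just renamed and moved into core/util/gemmlowp_common.cc. A sketch reconstructed from the removed code (the shared file's exact contents are an assumption):

// Sketch of the shared uint8 x uint8 -> int32 GEMM helper, reconstructed from
// the GemmlowpMultiply implementation removed from this file.
#include <cstdint>
#include <tuple>

#include "core/common/common.h"
#include "public/gemmlowp.h"

namespace onnxruntime {

Status GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data,
                                int32_t* result_data, const int lhs_offset,
                                const int rhs_offset, int m, int n, int k) {
  const std::tuple<> empty_pipeline = {};
  // Row-major maps for lhs (m x k), rhs (k x n) and the int32 result (m x n).
  const auto mat_order = gemmlowp::MapOrder::RowMajor;
  gemmlowp::MatrixMap<const std::uint8_t, mat_order> lhs(lhs_data, m, k);
  gemmlowp::MatrixMap<const std::uint8_t, mat_order> rhs(rhs_data, k, n);
  gemmlowp::MatrixMap<std::int32_t, mat_order> result(result_data, m, n);

  // gemmlowp subtracts the offsets it is given, hence the negation here.
  gemmlowp::GemmContext gemm_context;
  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                   gemmlowp::DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, empty_pipeline);

  return Status::OK();
}

}  // namespace onnxruntime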
104 changes: 31 additions & 73 deletions onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc
@@ -1,14 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef _MSC_VER
#pragma warning(disable : 4244)
#pragma warning(disable : 4267)
#endif

#include "core/providers/cpu/math/quantize_linear_matmul.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/util/gemmlowp_common_wrapper.h"
#include "core/providers/common.h"

namespace onnxruntime {

@@ -24,55 +19,7 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T3", DataTypeImpl::GetTensorType<uint8_t>()),
QLinearMatMul<uint8_t, uint8_t, uint8_t>);

Status GemmlowpMultiply(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data,
const int lhs_offset, const int rhs_offset, const int result_offset,
int m, int n, int k, int32_t int_multiplier, int32_t right_shift) {
gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage;
quantize_down_stage.result_offset_after_shift = result_offset;
quantize_down_stage.result_fixedpoint_multiplier = int_multiplier;
quantize_down_stage.result_shift = right_shift;
gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
const auto& output_pipeline = std::make_tuple(quantize_down_stage, saturating_cast_stage);

// TODO exp ColMajor order for rhs and result. That may be faster
const auto matOrder = gemmlowp::MapOrder::RowMajor;
gemmlowp::MatrixMap<const std::uint8_t, matOrder> lhs(lhs_data, m, k);
gemmlowp::MatrixMap<const std::uint8_t, matOrder> rhs(rhs_data, k, n);
gemmlowp::MatrixMap<std::uint8_t, matOrder> result(result_data, m, n);

gemmlowp::GemmContext gemm_context;
gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
gemmlowp::DefaultL8R8BitDepthParams>(
&gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, output_pipeline);

return Status::OK();
}

void QuantizeMultiplier(float fp_multiplier, std::int32_t* integer_multiplier, int* right_shift) {
auto* fp_as_bits = reinterpret_cast<uint32_t*>(&fp_multiplier);
auto current_exponent = (*fp_as_bits >> 23);
// bring multiplier in [.5,1) range and calculate the shift
auto bumped_multiplier_as_bits =
(*fp_as_bits & UINT32_C(0x007fffff)) | UINT32_C(0x3f000000);
auto* bumped_multiplier = reinterpret_cast<float*>(&bumped_multiplier_as_bits);
auto shift = 126 - current_exponent;
// convert to fixed point number
auto int_multiplier = static_cast<std::int64_t>(std::round(*bumped_multiplier * (1ll << 31)));

*integer_multiplier = static_cast<int32_t>(int_multiplier);
*right_shift = shift;
}

void ScaleAndZeropointPairValidationHelper(const Tensor* scale, const Tensor* zeropoint) {
ORT_ENFORCE(scale->Shape().NumDimensions() == 0 ||
(scale->Shape().NumDimensions() == 1 && scale->Shape().GetDims().size() == 1),
"scale must be a scalar");
ORT_ENFORCE(zeropoint->Shape().NumDimensions() == 0 ||
(zeropoint->Shape().NumDimensions() == 1 && zeropoint->Shape().GetDims().size() == 1),
"zeropoint must be a scalar");
}

template<>
template <>
Status QLinearMatMul<uint8_t, uint8_t, uint8_t>::Compute(OpKernelContext* ctx) const {
auto a = ctx->Input<Tensor>(0);
auto b = ctx->Input<Tensor>(3);
@@ -82,16 +29,27 @@ Status QLinearMatMul<uint8_t, uint8_t, uint8_t>::Compute(OpKernelContext* ctx) const {
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
Tensor* y = ctx->Output(0, helper.OutputShape());

// validate scale and zero points
// validate offsets
auto a_offset = ctx->Input<Tensor>(2);
auto b_offset = ctx->Input<Tensor>(5);
auto y_offset = ctx->Input<Tensor>(7);
ORT_ENFORCE(IsScalarOr1ElementVector(a_offset),
"QLinearMatmul : input zero point must be a scalar or 1D tensor of size 1");
ORT_ENFORCE(IsScalarOr1ElementVector(b_offset),
"QLinearMatmul : weight zero point must be a scalar or 1D tensor of size 1");
ORT_ENFORCE(IsScalarOr1ElementVector(y_offset),
"QLinearMatmul : result zero point must be a scalar or 1D tensor of size 1");

// validate scale
auto a_scale = ctx->Input<Tensor>(1);
auto a_zero_point = ctx->Input<Tensor>(2);
ScaleAndZeropointPairValidationHelper(a_scale, a_zero_point);
auto b_scale = ctx->Input<Tensor>(4);
auto b_zero_point = ctx->Input<Tensor>(5);
ScaleAndZeropointPairValidationHelper(b_scale, b_zero_point);
auto y_scale = ctx->Input<Tensor>(6);
auto y_zero_point = ctx->Input<Tensor>(7);
ScaleAndZeropointPairValidationHelper(y_scale, y_zero_point);
ORT_ENFORCE(IsScalarOr1ElementVector(a_scale),
"QLinearMatmul : input scale must be a scalar or 1D tensor of size 1");
ORT_ENFORCE(IsScalarOr1ElementVector(b_scale),
"QLinearMatmul : weight scale must be a scalar or 1D tensor of size 1");
ORT_ENFORCE(IsScalarOr1ElementVector(y_scale),
"QLinearMatmul : result scale must be a scalar or 1D tensor of size 1");

auto a_scale_data = *(a_scale->template Data<float>());
auto b_scale_data = *(b_scale->template Data<float>());
@@ -103,17 +61,17 @@ Status QLinearMatMul<uint8_t, uint8_t, uint8_t>::Compute(OpKernelContext* ctx) const {
QuantizeMultiplier(real_multiplier, &integer_multiplier, &right_shift);

for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
GemmlowpMultiply(a->template Data<uint8_t>() + helper.LeftOffsets()[i],
b->template Data<uint8_t>() + helper.RightOffsets()[i],
y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
*a_zero_point->template Data<uint8_t>(),
*b_zero_point->template Data<uint8_t>(),
*y_zero_point->template Data<uint8_t>(),
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()),
integer_multiplier,
right_shift);
GemmlowpMultiplyu8u8_u8(a->template Data<uint8_t>() + helper.LeftOffsets()[i],
b->template Data<uint8_t>() + helper.RightOffsets()[i],
y->template MutableData<uint8_t>() + helper.OutputOffsets()[i],
*a_offset->template Data<uint8_t>(),
*b_offset->template Data<uint8_t>(),
*y_offset->template Data<uint8_t>(),
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()),
integer_multiplier,
right_shift);
}

return Status::OK();
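Similarly, the requantizing GEMM and the QuantizeMultiplier helper removed above (which turns the float rescale factor a_scale * b_scale / y_scale into a Q31 fixed-point multiplier plus a right shift) presumably move into core/util/gemmlowp_common, with the GEMM exposed as GemmlowpMultiplyu8u8_u8. A sketch reconstructed from the removed code; again, the shared file's exact contents are an assumption:

// Sketch of the shared uint8 x uint8 -> uint8 GEMM helper with gemmlowp's
// quantize-down output pipeline, reconstructed from the code removed above.
#include <cstdint>
#include <tuple>

#include "core/common/common.h"
#include "public/gemmlowp.h"

namespace onnxruntime {

Status GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data,
                               uint8_t* result_data, const int lhs_offset,
                               const int rhs_offset, const int result_offset,
                               int m, int n, int k, int32_t int_multiplier,
                               int32_t right_shift) {
  // Requantize the int32 accumulators back to uint8: multiply by a Q31
  // fixed-point multiplier, shift right, add the result zero point, saturate.
  gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage;
  quantize_down_stage.result_offset_after_shift = result_offset;
  quantize_down_stage.result_fixedpoint_multiplier = int_multiplier;
  quantize_down_stage.result_shift = right_shift;
  gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
  const auto& output_pipeline = std::make_tuple(quantize_down_stage, saturating_cast_stage);

  const auto mat_order = gemmlowp::MapOrder::RowMajor;
  gemmlowp::MatrixMap<const std::uint8_t, mat_order> lhs(lhs_data, m, k);
  gemmlowp::MatrixMap<const std::uint8_t, mat_order> rhs(rhs_data, k, n);
  gemmlowp::MatrixMap<std::uint8_t, mat_order> result(result_data, m, n);

  gemmlowp::GemmContext gemm_context;
  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
                                   gemmlowp::DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, output_pipeline);

  return Status::OK();
}

}  // namespace onnxruntime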
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cpu/math/quantize_linear_matmul.h
@@ -6,6 +6,7 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"
#include "core/util/gemmlowp_common.h"

namespace onnxruntime {

@@ -16,6 +17,6 @@ class QLinearMatMul final : public OpKernel {
}

Status Compute(OpKernelContext* context) const override;

};
} // namespace onnxruntime
61 changes: 19 additions & 42 deletions onnxruntime/core/providers/cpu/nn/conv_integer.cc
@@ -1,15 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef _MSC_VER
#pragma warning(disable : 4244)
#pragma warning(disable : 4267)
#endif

#include "core/providers/cpu/nn/conv_integer.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/util/gemmlowp_common_wrapper.h"
#include "core/util/gemmlowp_common.h"
#include "core/providers/common.h"

namespace onnxruntime {

@@ -28,27 +25,17 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
size_t num_inputs = OpKernel::Node().InputDefs().size();
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
int32_t input_offset = 0;
int32_t filter_offset = 0;
uint8_t input_offset = 0;
uint8_t filter_offset = 0;
if (num_inputs >= 3) {
const auto* X_Zero_Point = context->Input<Tensor>(2);
if (X_Zero_Point->Shape().NumDimensions() == 0 ||
(X_Zero_Point->Shape().NumDimensions() == 1 && X_Zero_Point->Shape().GetDims().size() == 1)) {
input_offset = static_cast<int32_t>(*(X_Zero_Point->Data<uint8_t>()));
} else {
//TODO: Add support for per-channel quantization.
return Status(common::ONNXRUNTIME, common::FAIL, "Non per-tensor quantization is not supported now.");
}
ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
input_offset = *(X_Zero_Point->Data<uint8_t>());
}
if (num_inputs >= 4) {
const auto* W_Zero_Point = context->Input<Tensor>(3);
if (W_Zero_Point->Shape().NumDimensions() == 0 ||
(W_Zero_Point->Shape().NumDimensions() == 1 && W_Zero_Point->Shape().GetDims().size() == 1)) {
filter_offset = static_cast<int32_t>(*(W_Zero_Point->Data<uint8_t>()));
} else {
//TODO: Add support for per-channel quantization.
return Status(common::ONNXRUNTIME, common::FAIL, "Non per-tensor quantization is not supported now.");
}
ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
filter_offset = *(W_Zero_Point->Data<uint8_t>());
}

const int64_t N = X->Shape()[0];
@@ -118,27 +105,17 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
static_cast<int>(kernel_shape.size()),
col_buffer_data,
&CPUMathUtil::Instance(),
false,
input_offset);

const uint8_t* filter_data_as_uint8 = W->template Data<uint8_t>() + group_id * W_offset;
static const gemmlowp::MapOrder ResultOrder = gemmlowp::MapOrder::RowMajor;
static const gemmlowp::MapOrder LhsOrder = gemmlowp::MapOrder::RowMajor;
static const gemmlowp::MapOrder RhsOrder = gemmlowp::MapOrder::RowMajor;
gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
filter_data_as_uint8, static_cast<int>(M / group_), static_cast<int>(kernel_dim));
gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
col_buffer_data, static_cast<int>(kernel_dim), static_cast<int>(output_image_size));
gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
Ydata + group_id * Y_offset, static_cast<int>(M / group_), static_cast<int>(output_image_size));
const std::tuple<> empty_pipeline = {};

gemmlowp::GemmContext gemm_context;
// TODO: worker thread pool needs to be handled.
gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
gemmlowp::DefaultL8R8BitDepthParams>(
&gemm_context, lhs, rhs, &result, -filter_offset, -input_offset,
empty_pipeline);
false,
input_offset);

GemmlowpMultiplyu8u8_s32(W->template Data<uint8_t>() + group_id * W_offset,
col_buffer_data,
Ydata + group_id * Y_offset,
filter_offset,
input_offset,
static_cast<int>(M / group_),
static_cast<int>(output_image_size),
static_cast<int>(kernel_dim));
}

Xdata += X_offset * group_;
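For context on the call above: after Im2col, each group of the quantized convolution reduces to a single GEMM in which the filter block is the left-hand matrix and the unfolded input patches are the right-hand matrix. A small standalone sketch of the dimension bookkeeping (the concrete sizes are only an example; the variable names mirror the kernel):

// Sketch: GEMM shape for one group of a quantized convolution after im2col,
// matching the GemmlowpMultiplyu8u8_s32 call above. Example sizes: a 3x3
// conv, 8 input channels, 16 output channels, 32x32 output, one group.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t M = 16;                // total output channels
  const int64_t group = 1;             // number of groups
  const int64_t C_per_group = 8;       // input channels per group
  const int64_t kH = 3, kW = 3;        // kernel spatial size
  const int64_t outH = 32, outW = 32;  // output spatial size

  const int64_t kernel_dim = C_per_group * kH * kW;  // 72
  const int64_t output_image_size = outH * outW;     // 1024

  // lhs (filters):        (M / group) x kernel_dim
  // rhs (im2col buffer):  kernel_dim  x output_image_size
  // result (output):      (M / group) x output_image_size
  std::printf("GEMM: (%lld x %lld) * (%lld x %lld) -> (%lld x %lld)\n",
              static_cast<long long>(M / group), static_cast<long long>(kernel_dim),
              static_cast<long long>(kernel_dim), static_cast<long long>(output_image_size),
              static_cast<long long>(M / group), static_cast<long long>(output_image_size));
  return 0;
}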