add s8s8 support for quantized conv and gemm (microsoft#9902)
* add s8s8 support for quantized conv and gemm
yufenglee authored Dec 3, 2021
1 parent d8c7130 commit e613019
Showing 50 changed files with 4,713 additions and 652 deletions.
4 changes: 4 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -43,6 +43,7 @@ function(setup_mlas_source_for_windows)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
)

set(mlas_platform_preprocess_srcs
@@ -51,6 +52,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm
${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm
${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm
${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm
)
@@ -271,10 +273,12 @@ else()
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S
${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
)
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs})
13 changes: 11 additions & 2 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -21,7 +21,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
#if !defined(DISABLE_SPARSE_TENSORS)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SparseToDenseMatMul);
@@ -72,8 +72,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, MatMulIntegerToFloat);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DynamicQuantizeLSTM);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv);
#if defined(MLAS_TARGET_ARM_ANY)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t_int8_t, QLinearConv);
#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool);
#if defined(MLAS_TARGET_ARM_ANY)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t_int8_t, QLinearConv);
#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
// ******** End: Quantization ******************* //
@@ -161,6 +167,9 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, MatMulIntegerToFloat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DynamicQuantizeLSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv)>,
#if defined(MLAS_TARGET_ARM_ANY)
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t_int8_t, QLinearConv)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization)>,
@@ -198,7 +207,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SparseToDenseMatMul)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Pad)>,
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
@@ -82,7 +82,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
const auto* weights_data = static_cast<const uint8_t*>(weights.DataRaw());
weights_is_signed_ = weights.IsDataType<int8_t>();

packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size, weights_is_signed_);
packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size, false /*AIsSigned*/, weights_is_signed_);
if (packed_weights_size_ == 0) {
return Status::OK();
}
@@ -99,7 +99,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));

for (size_t i = 0; i < loop_len; i++) {
MlasGemmPackB(head_size, input_hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data);
MlasGemmPackB(head_size, input_hidden_size, weights_data, hidden_size_x3, false /*AIsSigned*/, weights_is_signed_, packed_weights_data);
packed_weights_data += packed_weights_size_;
weights_data += head_size;
}
@@ -227,13 +227,13 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
const auto* weights_data = packed_weights_ ? nullptr : static_cast<const uint8_t*>(weights->DataRaw());
const bool weights_is_signed = packed_weights_ ? weights_is_signed_ : weights->IsDataType<int8_t>();

MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape;
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = sequence_length;
gemm_shape.N = head_size;
gemm_shape.K = input_hidden_size;
gemm_shape.BIsSigned = weights_is_signed;

std::vector<MLAS_GEMM_U8X8_DATA_PARAMS> gemm_data_vec(loop_len);
std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(loop_len);
std::vector<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR> scale_bias_procs;
scale_bias_procs.reserve(loop_len);

@@ -53,7 +53,7 @@ Status DynamicQuantizeLSTM::TryPackWeights(const Tensor& weights, PackedWeights&
}

is_weight_signed = weights.IsDataType<int8_t>();
const size_t packed_weights_size = MlasGemmPackBSize(N, K, is_weight_signed);
const size_t packed_weights_size = MlasGemmPackBSize(N, K, false /*AIsSigned*/, is_weight_signed);
if (packed_weights_size == 0) {
return Status::OK();
}
@@ -73,7 +73,7 @@ Status DynamicQuantizeLSTM::TryPackWeights(const Tensor& weights, PackedWeights&

const auto* weights_data = static_cast<const uint8_t*>(weights.DataRaw());
for (int i = 0; i < num_directions_; i++) {
MlasGemmPackB(N, K, weights_data, N, is_weight_signed, packed_weights_data);
MlasGemmPackB(N, K, weights_data, N, false /*AIsSigned*/, is_weight_signed, packed_weights_data);
packed_weights_data = static_cast<uint8_t*>(packed_weights_data) + packed_weights_size;
weights_data += N * K;
}
@@ -113,7 +113,7 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
}

// batch gemm
MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape;
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(helper.M());
gemm_shape.N = static_cast<size_t>(helper.N());
gemm_shape.K = static_cast<size_t>(helper.K());
@@ -122,7 +122,7 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
const size_t num_gemms = helper.OutputOffsets().size();
std::vector<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR> gemm_scale_procs;
gemm_scale_procs.reserve(num_gemms);
std::vector<MLAS_GEMM_U8X8_DATA_PARAMS> gemm_data_vec(num_gemms);
std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(num_gemms);

for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) {
gemm_scale_procs.emplace_back(y_data + helper.OutputOffsets()[gemm_idx],
13 changes: 8 additions & 5 deletions onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc
@@ -82,8 +82,8 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
GemmBroadcastBias(M, N, 1.f, c->template Data<int32_t>(), &(c->Shape()), gemm_output_data);
}

MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape{M, N, K, b_is_signed, c != nullptr};
MLAS_GEMM_U8X8_DATA_PARAMS gemm_param;
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape{M, N, K, false /*AIsSigned*/, b_is_signed, c != nullptr};
MLAS_GEMM_QUANT_DATA_PARAMS gemm_param;

gemm_param.A = a_data;
gemm_param.lda = gemm_shape.K;
@@ -177,17 +177,20 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
size_t out_lda,
const std::vector<float>& output_scales,
Tensor* y,
MLAS_GEMM_U8X8_DATA_PARAMS& gemm_param,
MLAS_GEMM_QUANT_DATA_PARAMS& gemm_param,
std::unique_ptr<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR>& scale_bias_proc_ptr,
std::unique_ptr<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR>& requant_proc_ptr) {
if (nullptr != y_zp) {
bool is_y_signed = y->IsDataType<int8_t>();
int32_t y_zero_point = is_y_signed ? *y_zp->template Data<int8_t>() : *y_zp->template Data<uint8_t>();
requant_proc_ptr = std::make_unique<MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR>(
static_cast<uint8_t*>(y->MutableDataRaw()),
y->MutableDataRaw(),
out_lda,
nullptr,
output_scales.data(),
output_scales.size() > 1,
*y_zp->template Data<uint8_t>());
y_zero_point,
is_y_signed);
gemm_param.OutputProcessor = requant_proc_ptr.get();
} else {
scale_bias_proc_ptr = std::make_unique<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR>(
96 changes: 67 additions & 29 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -53,6 +53,10 @@ Module Name:
#if defined(_M_ARM) || defined(__arm__)
#define MLAS_TARGET_ARM
#endif
#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) || defined(MLAS_TARGET_ARM)
#define MLAS_TARGET_ARM_ANY
#endif

#if defined(__VSX__)
#define MLAS_TARGET_POWER
#endif
@@ -528,15 +532,23 @@ class MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR
MLAS_QUANTIZATION_GRANULARITY QuantGran_;
};

struct MLAS_GEMM_U8X8_SHAPE_PARAMS {
size_t M = 0;
size_t N = 0;
size_t K = 0;
bool BIsSigned = false;
bool IsAccumulateMode = false;
/**
* @brief Supply matrices shape and data type information to quantized gemm functions
*
** NOTE: AIsSigned == true is not supported on non-ARM devices for now.
** AIsSigned == true is supported on ARM devices when BIsSigned is also true.
*
*/
struct MLAS_GEMM_QUANT_SHAPE_PARAMS {
size_t M = 0; /**< Supplies the row size of matrix A */
size_t N = 0; /**< Supplies the column size of matrix B */
size_t K = 0; /**< Supplies the column size of matrix A and row size of matrix B */
bool AIsSigned = false; /**< Indicates whether type of A is int8_t or uint8_t.*/
bool BIsSigned = false; /**< Indicates whether type of B is int8_t or uint8_t */
bool IsAccumulateMode = false; /**< Indicates whether to accumulate to matrix C or override matrix C */
};

struct MLAS_GEMM_U8X8_DATA_PARAMS {
struct MLAS_GEMM_QUANT_DATA_PARAMS {
const uint8_t* A = nullptr;
size_t lda = 0;
uint8_t ZeroPointA = 0;
@@ -553,8 +565,8 @@ struct MLAS_GEMM_U8X8_DATA_PARAMS {
void
MLASCALL
MlasGemm(
const MLAS_GEMM_U8X8_SHAPE_PARAMS& Shape,
const MLAS_GEMM_U8X8_DATA_PARAMS& DataParams,
const MLAS_GEMM_QUANT_SHAPE_PARAMS& Shape,
const MLAS_GEMM_QUANT_DATA_PARAMS& DataParams,
MLAS_THREADPOOL* ThreadPool
);
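
Usage sketch (not part of the diff): with the renamed shape/data structs, an s8s8 call might look like the following. M, N, K, a_int8, b_int8, c_int32, a_zero_point and thread_pool are placeholder names, and the B/ldb/C/ldc fields of MLAS_GEMM_QUANT_DATA_PARAMS are assumed to keep their existing names from the old MLAS_GEMM_U8X8_DATA_PARAMS, since this hunk does not show them.

    // Illustrative s8s8 GEMM: both A and B hold int8_t data, so both signed flags are set.
    MLAS_GEMM_QUANT_SHAPE_PARAMS shape;
    shape.M = M;
    shape.N = N;
    shape.K = K;
    shape.AIsSigned = true;   // new flag; per the note above, only supported on ARM and only together with BIsSigned
    shape.BIsSigned = true;

    MLAS_GEMM_QUANT_DATA_PARAMS data;
    data.A = reinterpret_cast<const uint8_t*>(a_int8);     // A keeps its uint8_t* type; int8 data is reinterpreted
    data.lda = K;
    data.ZeroPointA = static_cast<uint8_t>(a_zero_point);  // zero point travels in the same 8-bit storage
    data.B = b_int8;                                        // assumed field, not shown in this hunk
    data.ldb = N;
    data.C = c_int32;                                       // int32 accumulator output
    data.ldc = N;

    MlasGemm(shape, data, thread_pool);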

Expand All @@ -572,8 +584,8 @@ MlasGemm(
void
MLASCALL
MlasGemmBatch(
const MLAS_GEMM_U8X8_SHAPE_PARAMS& Shape,
const MLAS_GEMM_U8X8_DATA_PARAMS* DataParams,
const MLAS_GEMM_QUANT_SHAPE_PARAMS& Shape,
const MLAS_GEMM_QUANT_DATA_PARAMS* DataParams,
const size_t BatchN,
MLAS_THREADPOOL* ThreadPool
);
@@ -605,6 +617,7 @@ MLASCALL
MlasGemmPackBSize(
size_t N,
size_t K,
bool AIsSigned,
bool BIsSigned
);

@@ -615,6 +628,7 @@ MlasGemmPackB(
size_t K,
const uint8_t* B,
size_t ldb,
bool AIsSigned,
bool BIsSigned,
void* PackedB
);
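
A sketch of the corresponding prepack path, mirroring the call sites updated above in attention_quant.cc and the LSTM kernel but with signed activations; b_data, N, K and alloc are placeholder names:

    // Packed-B size and layout now depend on whether A will be signed as well.
    size_t packed_size = MlasGemmPackBSize(N, K, /*AIsSigned=*/true, /*BIsSigned=*/true);
    if (packed_size != 0) {
        void* packed_b = alloc(packed_size);  // caller-provided allocation, placeholder here
        MlasGemmPackB(N, K, b_data, /*ldb=*/N, /*AIsSigned=*/true, /*BIsSigned=*/true, packed_b);
    }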
@@ -696,10 +710,11 @@ MlasConv(
void
MLASCALL
MlasConvDepthwise(
const uint8_t* const* Input,
uint8_t InputZeroPoint,
const uint8_t* Filter,
uint8_t FilterZeroPoint,
const void* const* Input,
int32_t InputZeroPoint,
bool InputIsSigned,
const void* Filter,
int32_t FilterZeroPoint,
bool FilterIsSigned,
int32_t* Output,
size_t Channels,
@@ -716,7 +731,8 @@ MlasConvSymPackWSize(
size_t GroupCount,
size_t InputChannels,
size_t OutputChannels,
size_t KernelSize
size_t KernelSize,
bool InputIsSigned
);

void
@@ -727,27 +743,30 @@ MlasConvSymPackW(
size_t KernelSize,
const int8_t* W,
int8_t* PackedW,
size_t PackedWSize
size_t PackedWSize,
bool InputIsSigned
);

int32_t
MlasConvSymFixupInputZeroPoint(
uint8_t zero_point_value
int32_t zero_point_value,
bool InputIsSigned
);

struct MLAS_CONV_SYM_PARAMS {
const uint8_t* InputDirect;
const uint8_t* const* InputIndirection;
const void* InputDirect;
const void* const* InputIndirection;
const void* Filter;
uint8_t* Output;
void* Output;
size_t InputChannels;
size_t OutputChannels;
size_t OutputCount;
size_t KernelSize;
const int32_t* Bias;
const float* Scale;
bool PerChannelScale;
uint8_t OutputZeroPoint;
int32_t OutputZeroPoint;
bool InputIsSigned;
};

void
@@ -870,6 +889,15 @@ MlasTranspose(
size_t N
);

void
MLASCALL
MlasTranspose(
const int8_t* Input,
int8_t* Output,
size_t M,
size_t N
);

void
MLASCALL
MlasTranspose(
@@ -1064,18 +1092,20 @@ class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR
{
public:
MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR(
uint8_t* Output,
void* Output,
size_t OutputLeadingDimension,
const int32_t* Bias,
const float* Scale,
bool PerColumnScale,
uint8_t ZeroPoint)
int32_t ZeroPoint,
bool OutputIsSigned)
: Output_(Output),
OutputLeadingDimension_(OutputLeadingDimension),
Bias_(Bias),
Scale_(Scale),
PerColumnScale_(PerColumnScale),
ZeroPoint_(ZeroPoint)
ZeroPoint_(ZeroPoint),
OutputIsSigned_(OutputIsSigned)
{
}

@@ -1086,18 +1116,26 @@ class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR
size_t CountN,
size_t ldc) const override
{
MlasRequantizeOutput(C, ldc, Output_, OutputLeadingDimension_, Bias_, Scale_,
PerColumnScale_, ZeroPoint_, StartM, StartN, CountM, CountN);
if(OutputIsSigned_){
MlasRequantizeOutput(C, ldc, reinterpret_cast<int8_t*>(Output_), OutputLeadingDimension_,
Bias_, Scale_, PerColumnScale_, static_cast<int8_t>(ZeroPoint_),
StartM, StartN, CountM, CountN);
} else {
MlasRequantizeOutput(C, ldc, reinterpret_cast<uint8_t*>(Output_), OutputLeadingDimension_,
Bias_, Scale_, PerColumnScale_, static_cast<uint8_t>(ZeroPoint_),
StartM, StartN, CountM, CountN);
}
}


private:
uint8_t* Output_;
void* Output_;
size_t OutputLeadingDimension_;
const int32_t* Bias_;
const float* Scale_;
bool PerColumnScale_;
uint8_t ZeroPoint_;
int32_t ZeroPoint_;
bool OutputIsSigned_;
};
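
Usage sketch (not part of the diff): the processor now takes a void* output plus an OutputIsSigned flag and dispatches to the signed or unsigned MlasRequantizeOutput overload, as the updated quant_gemm.cc call site does. y_int8, ldy, scales, y_zero_point and gemm_param are placeholder names:

    // Requantize int32 accumulators directly into a signed int8 output tensor.
    MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR requant(
        y_int8,                  // void* output; int8_t* and uint8_t* tensors share one processor
        ldy,
        nullptr,                 // no bias
        scales.data(),
        scales.size() > 1,       // per-column scales when more than one scale is supplied
        y_zero_point,            // int32_t here, narrowed inside Process()
        /*OutputIsSigned=*/true);
    gemm_param.OutputProcessor = &requant;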

