Adding prepacking to QLinearMatMul (microsoft#6980)
Reuse the prepacking logic already in MatMulInteger to enable weight prepacking for QLinearMatMul. Currently only 2D weight matrices are prepacked.
chenfucn authored Mar 17, 2021
1 parent 90642e7 commit 03885af
Showing 7 changed files with 252 additions and 53 deletions.
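
The mechanics are easiest to see in miniature. The sketch below is a hypothetical, self-contained distillation of the pattern the diffs apply (all names are invented for illustration, and a plain byte copy stands in for the real MlasGemmPackB): a shared base class packs Matrix B once at session initialization, and each derived kernel only reports which input slot holds B.

#include <cstdint>
#include <memory>
#include <vector>

// Base kernel: owns the packed copy of Matrix B and decides when to pack.
class MatMulIntegerBaseSketch {
 public:
  // Called once per constant initializer when the session is created.
  // Packs only when the initializer is this kernel's B input and is 2-D.
  void PrePack(const std::vector<uint8_t>& tensor, int input_idx,
               bool is_2d, bool& is_packed) {
    is_packed = false;
    if (input_idx != GetBIdx() || !is_2d) return;  // only pack Matrix B
    // Stand-in for MlasGemmPackB: cache a layout-optimized copy of B.
    packed_b_ = std::make_unique<std::vector<uint8_t>>(tensor);
    is_packed = true;  // tells the framework it may drop the original weight
  }

 protected:
  virtual int GetBIdx() const = 0;  // which input slot holds Matrix B
  std::unique_ptr<std::vector<uint8_t>> packed_b_;
};

// MatMulInteger's inputs are (A, B, a_zero_point, b_zero_point): B is slot 1.
class MatMulIntegerSketch : public MatMulIntegerBaseSketch {
 protected:
  int GetBIdx() const override { return 1; }
};

// QLinearMatMul's inputs are (A, a_scale, a_zero_point, B, ...): B is slot 3.
class QLinearMatMulSketch : public MatMulIntegerBaseSketch {
 protected:
  int GetBIdx() const override { return 3; }
};

int main() {
  QLinearMatMulSketch kernel;
  std::vector<uint8_t> weights(16, 1);  // a 4x4 toy weight matrix
  bool packed = false;
  kernel.PrePack(weights, /*input_idx=*/3, /*is_2d=*/true, packed);
  return packed ? 0 : 1;  // packed: B was cached before the first Compute()
}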
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -251,7 +251,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
   const Tensor* mask_index = context->Input<Tensor>(3);
   const Tensor* past = context->Input<Tensor>(4);
 
-  const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape());
+  const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_);
   ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
                                   weights_shape,
                                   bias->Shape(),
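
The one-line change above follows the convention the prepacking path relies on everywhere: once a weight has been packed at session initialization, the framework may pass a null tensor for that input, so the reliable branch is on the input pointer rather than on the packed flag. Roughly (a sketch of the convention, not the surrounding attention.cc code):

// Under prepacking, a packed input can arrive as nullptr at Compute() time.
const Tensor* weights = context->Input<Tensor>(1);  // may be null once prepacked
const TensorShape& weights_shape =
    (weights ? weights->Shape() : weight_shape_);   // fall back to the cached shape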
54 changes: 40 additions & 14 deletions onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
@@ -17,7 +17,9 @@ class MatMulIntegerToFloatBase : public MatMulIntegerBase {
   MatMulIntegerToFloatBase(const OpKernelInfo& info) : MatMulIntegerBase(info) {
   }
 
- protected:
+  enum OutputTensors : int { OUT_Y = 0 };
+
+ protected:
   Status ComputeCommon(OpKernelContext* ctx,
                        const uint8_t* a_data,
                        const TensorShape& a_shape,
@@ -38,7 +40,7 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
                                                const Tensor* bias_tensor) const {
   MatMulComputeHelper helper;
   ORT_RETURN_IF_ERROR(helper.Compute(a_shape, packed_b_ ? b_shape_ : b->Shape()));
-  Tensor* y = ctx->Output(0, helper.OutputShape());
+  Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
 
   // Bail out early if the output is going to be empty
   if (y->Shape().Size() == 0)
@@ -86,25 +88,49 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
   DynamicQuantizeMatMul(const OpKernelInfo& info) : MatMulIntegerToFloatBase(info) {}
 
   Status Compute(OpKernelContext* context) const override;
+
+  enum InputTensors : int {
+    IN_A = 0,
+    IN_B = 1,
+    IN_B_SCALE = 2,
+    IN_B_ZERO_POINT = 3,
+    IN_BIAS = 4
+  };
+
+ protected:
+  int GetBIdx() override { return IN_B; }
 };
 
 class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase {
  public:
   MatMulIntegerToFloat(const OpKernelInfo& info) : MatMulIntegerToFloatBase(info) {}
 
   Status Compute(OpKernelContext* context) const override;
+
+  enum InputTensors : int {
+    IN_A = 0,
+    IN_B = 1,
+    IN_A_SCALE = 2,
+    IN_B_SCALE = 3,
+    IN_A_ZERO_POINT = 4,
+    IN_B_ZERO_POINT = 5,
+    IN_BIAS = 6
+  };
+
+ protected:
+  int GetBIdx() override { return IN_B; }
 };
 
 Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
-  const Tensor* a = ctx->Input<Tensor>(0);
-  const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
+  const Tensor* a = ctx->Input<Tensor>(IN_A);
+  const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);
 
-  const Tensor* b_scale_tensor = ctx->Input<Tensor>(2);
+  const Tensor* b_scale_tensor = ctx->Input<Tensor>(IN_B_SCALE);
   ORT_ENFORCE(IsScalarOr1ElementVector(b_scale_tensor),
               "DynamicQuantizeMatMul : input B scale must be a scalar or 1D tensor of size 1. Per-Channel is not supported yet.");
   float b_scale = *b_scale_tensor->template Data<float>();
 
-  const Tensor* b_zero_point_tensor = ctx->Input<Tensor>(3);
+  const Tensor* b_zero_point_tensor = ctx->Input<Tensor>(IN_B_ZERO_POINT);
   uint8_t b_zero_point = 0;
   if (b_zero_point_tensor != nullptr) {
     ORT_ENFORCE(IsScalarOr1ElementVector(b_zero_point_tensor),
@@ -134,34 +160,34 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
                        b,
                        b_zero_point,
                        a_scale * b_scale,
-                       ctx->Input<Tensor>(4));
+                       ctx->Input<Tensor>(IN_BIAS));
 }
 
 Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
-  const Tensor* a = ctx->Input<Tensor>(0);
-  const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
+  const Tensor* a = ctx->Input<Tensor>(IN_A);
+  const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);
 
-  const Tensor* a_scale_tensor = ctx->Input<Tensor>(2);
+  const Tensor* a_scale_tensor = ctx->Input<Tensor>(IN_A_SCALE);
   ORT_ENFORCE(IsScalarOr1ElementVector(a_scale_tensor),
               "MatMulIntegerToFloat : input A scale must be a scalar or 1D tensor of size 1. Per-Channel is not supported yet.");
   float a_scale = *a_scale_tensor->template Data<float>();
 
-  const Tensor* b_scale_tensor = ctx->Input<Tensor>(3);
+  const Tensor* b_scale_tensor = ctx->Input<Tensor>(IN_B_SCALE);
   ORT_ENFORCE(IsScalarOr1ElementVector(b_scale_tensor),
               "MatMulIntegerToFloat : input B scale must be a scalar or 1D tensor of size 1. Per-Channel is not supported yet.");
   float b_scale = *b_scale_tensor->template Data<float>();
 
   // validate zero points
   uint8_t a_zero_point = 0;
-  const Tensor* a_zero_point_tensor = ctx->Input<Tensor>(4);
+  const Tensor* a_zero_point_tensor = ctx->Input<Tensor>(IN_A_ZERO_POINT);
   if (a_zero_point_tensor != nullptr) {
     ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point_tensor),
                 "MatMulIntegerToFloat : input A zero point must be a scalar or 1D tensor of size 1. Per-Channel is not supported yet.");
     a_zero_point = *a_zero_point_tensor->Data<uint8_t>();
   }
 
   uint8_t b_zero_point = 0;
-  const Tensor* b_zero_point_tensor = ctx->Input<Tensor>(5);
+  const Tensor* b_zero_point_tensor = ctx->Input<Tensor>(IN_B_ZERO_POINT);
   if (b_zero_point_tensor != nullptr) {
     ORT_ENFORCE(IsScalarOr1ElementVector(b_zero_point_tensor),
                 "MatMulIntegerToFloat : input B zero point must be a scalar or 1D tensor of size 1. Per-Channel is not supported yet.");
@@ -175,7 +201,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
                        b,
                        b_zero_point,
                        a_scale * b_scale,
-                       ctx->Input<Tensor>(6));
+                       ctx->Input<Tensor>(IN_BIAS));
 }
 
 ONNX_OPERATOR_TYPED_KERNEL_EX(
55 changes: 38 additions & 17 deletions onnxruntime/core/providers/cpu/math/matmul_integer.cc
@@ -14,6 +14,18 @@ class MatMulInteger final : public MatMulIntegerBase {
   MatMulInteger(const OpKernelInfo& info) : MatMulIntegerBase(info) {}
 
   Status Compute(OpKernelContext* context) const override;
+
+  enum InputTensors : int {
+    IN_A = 0,
+    IN_B = 1,
+    IN_A_ZERO_POINT = 2,
+    IN_B_ZERO_POINT = 3
+  };
+
+  enum OutputTensors : int { OUT_Y = 0 };
+
+ protected:
+  int GetBIdx() override { return IN_B; }
 };
 
 ONNX_OPERATOR_TYPED_KERNEL_EX(
@@ -29,27 +41,43 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
     MatMulInteger);
 
 Status MatMulInteger::Compute(OpKernelContext* ctx) const {
-  const auto* a = ctx->Input<Tensor>(0);
-  const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
 
   MatMulComputeHelper helper;
-  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b ? b->Shape() : b_shape_));
-  Tensor* y = ctx->Output(0, helper.OutputShape());
+  const auto* a = ctx->Input<Tensor>(IN_A);
 
+  const uint8_t* b_data;
+  bool b_is_signed;
+  if (packed_b_) {
+    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_));
+    b_data = static_cast<const uint8_t*>(packed_b_.get());
+    b_is_signed = b_is_signed_;
+  } else {
+    const Tensor* b = ctx->Input<Tensor>(IN_B);
+    if (b == nullptr) {
+      // For required input, the framework has checks to ensure this won't happen.
+      // this is dead code to quiet compiler warning.
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Required input B can not be null!");
+    }
+    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
+    b_data = static_cast<const uint8_t*>(b->DataRaw());
+    b_is_signed = b->IsDataType<int8_t>();
+  }
+
+  Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
   // Bail out early if the output is going to be empty
   if (y->Shape().Size() == 0)
     return Status::OK();
 
   // validate zero points
   uint8_t a_offset = 0;
   uint8_t b_offset = 0;
-  const auto* a_zero_point = ctx->Input<Tensor>(2);
+  const auto* a_zero_point = ctx->Input<Tensor>(IN_A_ZERO_POINT);
   if (a_zero_point != nullptr) {
     ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
                 "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
     a_offset = *a_zero_point->template Data<uint8_t>();
   }
-  const auto* b_zero_point = ctx->Input<Tensor>(3);
+  const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
   if (b_zero_point != nullptr) {
     ORT_ENFORCE(IsScalarOr1ElementVector(b_zero_point),
                 "MatmulInteger : input2 zero point must be a scalar or 1D tensor of size 1");
@@ -65,22 +93,15 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
   gemm_params.ldb = gemm_params.N;
   gemm_params.ZeroPointB = &b_offset;
   gemm_params.ldc = gemm_params.N;
+  gemm_params.BIsPacked = bool(packed_b_);
+  gemm_params.BIsSigned = b_is_signed;
 
   const auto* a_data = a->template Data<uint8_t>();
   auto* y_data = y->template MutableData<int32_t>();
 
   for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
     gemm_params.A = a_data + helper.LeftOffsets()[i];
-    if (packed_b_) {
-      gemm_params.B = packed_b_.get();
-      gemm_params.BIsPacked = true;
-      gemm_params.BIsSigned = b_is_signed_;
-    } else if (b != nullptr) {
-      gemm_params.B = static_cast<const uint8_t*>(b->DataRaw()) + + helper.RightOffsets()[i];
-      gemm_params.BIsSigned = b->IsDataType<int8_t>();
-    } else {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input B should not be null.");
-    }
+    gemm_params.B = b_data + helper.RightOffsets()[i];
     gemm_params.C = y_data + helper.OutputOffsets()[i];
     MlasGemm(&gemm_params, ctx->GetOperatorThreadPool());
   }
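
The rewritten loop can treat the packed and unpacked cases uniformly because b_data is resolved once up front and the per-batch arithmetic lives in the precomputed offsets. For the 2D B case this commit targets, the offset math works out as in the sketch below (an illustration of the usual broadcast behavior, not MatMulComputeHelper's actual code):

#include <cstddef>
#include <vector>

// Offsets for C[i] = A[i] * B with A: [batch, M, K] and a 2-D B: [K, N].
// Because B is broadcast across the batch, every right offset is zero, which
// is what lets a single packed copy of B serve all batch slices.
struct BatchOffsets {
  std::vector<size_t> left, right, output;
};

BatchOffsets ComputeOffsets(size_t batch, size_t M, size_t K, size_t N) {
  BatchOffsets o;
  for (size_t i = 0; i < batch; ++i) {
    o.left.push_back(i * M * K);    // i-th slice of A
    o.right.push_back(0);           // B reused (packed or not) for every slice
    o.output.push_back(i * M * N);  // i-th slice of C
  }
  return o;
}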
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/cpu/math/matmul_integer_base.h
@@ -15,7 +15,7 @@ class MatMulIntegerBase : public OpKernel {
     is_packed = false;
 
     // only pack Matrix B
-    if (input_idx == 1) {
+    if (input_idx == GetBIdx()) {
       // Only handle the common case of a 2D weight matrix. Additional matrices
       // could be handled by stacking the packed buffers.
       b_shape_ = tensor.Shape();
@@ -44,6 +44,11 @@ class MatMulIntegerBase : public OpKernel {
   }
 
  protected:
+  /**
+   * @return input index of Matrix B, the weight tensor
+   */
+  virtual int GetBIdx() = 0;
+
   bool b_is_signed_{true};
   TensorShape b_shape_;
   BufferUniquePtr packed_b_;
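
With this hook, any future kernel in the family gets weight prepacking by overriding a single method. A hypothetical derived kernel whose weight sits at input slot 2 would only need (a sketch in the style of the classes above, not real ORT code):

// Hypothetical kernel: PrePack() is inherited unchanged from MatMulIntegerBase;
// the base packs whichever input slot this override names.
class MyQuantizedMatMul final : public MatMulIntegerBase {
 public:
  MyQuantizedMatMul(const OpKernelInfo& info) : MatMulIntegerBase(info) {}

  Status Compute(OpKernelContext* context) const override;

 protected:
  int GetBIdx() override { return 2; }  // weight tensor is input slot 2
};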
53 changes: 33 additions & 20 deletions onnxruntime/core/providers/cpu/math/quantize_linear_matmul.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "quantize_linear_matmul.h"
+
 #include "core/framework/op_kernel.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "core/common/safeint.h"
@@ -11,13 +13,6 @@
 
 namespace onnxruntime {
 
-class QLinearMatMul final : public OpKernel {
- public:
-  QLinearMatMul(const OpKernelInfo& info) : OpKernel(info) {}
-
-  Status Compute(OpKernelContext* context) const override;
-};
-
 ONNX_OPERATOR_KERNEL_EX(
     QLinearMatMul,
     kOnnxDomain,
@@ -30,21 +25,37 @@ ONNX_OPERATOR_KERNEL_EX(
     QLinearMatMul);
 
 Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
-  const auto* a = ctx->Input<Tensor>(0);
-  const auto* b = ctx->Input<Tensor>(3);
 
   MatMulComputeHelper helper;
-  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
-  Tensor* y = ctx->Output(0, helper.OutputShape());
+  const auto* a = ctx->Input<Tensor>(IN_A);
 
+  const uint8_t* b_data;
+  bool b_is_signed;  // can't modify b_is_signed_, this is a const method
+  if (packed_b_) {
+    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_));
+    b_data = static_cast<const uint8_t*>(packed_b_.get());
+    b_is_signed = b_is_signed_;
+  } else {
+    const Tensor* b = ctx->Input<Tensor>(IN_B);
+    if (b == nullptr) {
+      // For required input, the framework has checks to ensure this won't happen,
+      // dead code to quiet the compiler warning.
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Required input B can not be null!");
+    }
+    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
+    b_data = static_cast<const uint8_t*>(b->DataRaw());
+    b_is_signed = b->IsDataType<int8_t>();
+  }
+
+  Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
   // Bail out early if the output is going to be empty
   if (y->Shape().Size() == 0)
     return Status::OK();
 
   // validate offsets
-  const auto* a_offset = ctx->Input<Tensor>(2);
-  const auto* b_offset = ctx->Input<Tensor>(5);
-  const auto* y_offset = ctx->Input<Tensor>(7);
+  const auto* a_offset = ctx->Input<Tensor>(IN_A_ZERO_POINT);
+  const auto* b_offset = ctx->Input<Tensor>(IN_B_ZERO_POINT);
+  const auto* y_offset = ctx->Input<Tensor>(IN_Y_ZERO_POINT);
   ORT_ENFORCE(IsScalarOr1ElementVector(a_offset),
               "QLinearMatmul : input zero point must be a scalar or 1D tensor of size 1");
   ORT_ENFORCE(IsScalarOr1ElementVector(b_offset),
@@ -53,9 +64,9 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
               "QLinearMatmul : result zero point must be a scalar or 1D tensor of size 1");
 
   // validate scale
-  const auto* a_scale = ctx->Input<Tensor>(1);
-  const auto* b_scale = ctx->Input<Tensor>(4);
-  const auto* y_scale = ctx->Input<Tensor>(6);
+  const auto* a_scale = ctx->Input<Tensor>(IN_A_SCALE);
+  const auto* b_scale = ctx->Input<Tensor>(IN_B_SCALE);
+  const auto* y_scale = ctx->Input<Tensor>(IN_Y_SCALE);
   ORT_ENFORCE(IsScalarOr1ElementVector(a_scale),
               "QLinearMatmul : input scale must be a scalar or 1D tensor of size 1");
   ORT_ENFORCE(IsScalarOr1ElementVector(b_scale),
@@ -84,13 +95,15 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
   gemm_params.ZeroPointA = *a_offset->template Data<uint8_t>();
   gemm_params.ldb = gemm_params.N;
   gemm_params.ZeroPointB = static_cast<const uint8_t*>(b_offset->DataRaw());
-  gemm_params.BIsSigned = b->IsDataType<int8_t>();
   gemm_params.C = gemm_output;
   gemm_params.ldc = gemm_params.N;
+  gemm_params.BIsPacked = bool(packed_b_);
+  gemm_params.BIsSigned = b_is_signed;
 
   for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
     gemm_params.A = a->template Data<uint8_t>() + helper.LeftOffsets()[i];
-    gemm_params.B = static_cast<const uint8_t*>(b->DataRaw()) + helper.RightOffsets()[i];
+    gemm_params.B = b_data + helper.RightOffsets()[i];
 
     MlasGemm(&gemm_params, ctx->GetOperatorThreadPool());
 
     MlasRequantizeOutput(gemm_output,
41 changes: 41 additions & 0 deletions onnxruntime/core/providers/cpu/math/quantize_linear_matmul.h
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+
+#include "matmul_integer_base.h"
+
+namespace onnxruntime {
+
+// Allow subclassing for test only
+class QLinearMatMul : public MatMulIntegerBase {
+ public:
+  QLinearMatMul(const OpKernelInfo& info) : MatMulIntegerBase(info) {}
+
+  Status Compute(OpKernelContext* context) const override;
+
+  /**
+   * @brief Give each input a name, should be consistent with doc spec in
+   * Operators.md
+   */
+  enum InputTensors : int {
+    IN_A = 0,
+    IN_A_SCALE = 1,
+    IN_A_ZERO_POINT = 2,
+    IN_B = 3,
+    IN_B_SCALE = 4,
+    IN_B_ZERO_POINT = 5,
+    IN_Y_SCALE = 6,
+    IN_Y_ZERO_POINT = 7
+  };
+
+  enum OutputTensors : int {
+    OUT_Y = 0
+  };
+
+ protected:
+  int GetBIdx() override {
+    return IN_B;
+  }
+};
+
+}  // namespace onnxruntime
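
Note the class is deliberately not final ("Allow subclassing for test only"): a test can derive from it to observe the protected prepacking state that the public interface hides. A hypothetical probe might look like this (invented name, sketch only):

// Hypothetical test-only subclass: reports whether PrePack() actually cached
// the weights, using the protected packed_b_ member from MatMulIntegerBase.
class QLinearMatMulTestProbe : public QLinearMatMul {
 public:
  QLinearMatMulTestProbe(const OpKernelInfo& info) : QLinearMatMul(info) {}

  bool WeightIsPrepacked() const { return packed_b_ != nullptr; }
};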