Implement the matmul op with oneDNN to leverage AMX optimization. #413

Open · wants to merge 1 commit into base: dev
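This change routes gemma.cpp's MatMul through oneDNN's bf16 matmul primitive so that AMX can be used on CPUs that support it. As orientation, the sketch below shows the basic oneDNN flow the new MatMul_dnnl follows. It is illustrative only: the function name, shapes, and pointers are placeholders, and the real code in ops/matmul-inl.h additionally reorders f32 inputs to bf16 and applies the weight scale and optional bias via primitive_attr.

```cpp
#include "oneapi/dnnl/dnnl.hpp"

// Minimal bf16 matmul sketch: C (f32, MxN) = A (bf16, MxK) * B (bf16, KxN),
// with B supplied in transposed storage (tag::ba), matching the PR's layout.
void SketchDnnlMatMul(void* a_bf16, void* b_bf16_trans, float* c,
                      dnnl::memory::dim M, dnnl::memory::dim K,
                      dnnl::memory::dim N) {
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);
  dnnl::stream stream(engine);

  // Describe the operands; oneDNN picks an AMX-capable kernel when available.
  auto src_md = dnnl::memory::desc({M, K}, dt::bf16, tag::ab);
  auto weights_md = dnnl::memory::desc({K, N}, dt::bf16, tag::ba);
  auto dst_md = dnnl::memory::desc({M, N}, dt::f32, tag::ab);

  // Wrap the caller's buffers without copying.
  dnnl::memory src_mem(src_md, engine, a_bf16);
  dnnl::memory weights_mem(weights_md, engine, b_bf16_trans);
  dnnl::memory dst_mem(dst_md, engine, c);

  // Build and execute the matmul primitive.
  auto pd = dnnl::matmul::primitive_desc(engine, src_md, weights_md, dst_md);
  dnnl::matmul(pd).execute(stream, {{DNNL_ARG_SRC, src_mem},
                                    {DNNL_ARG_WEIGHTS, weights_mem},
                                    {DNNL_ARG_DST, dst_mem}});
  stream.wait();
}
```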
2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -83,6 +83,8 @@ cc_library(
"@hwy//:matvec",
"@hwy//:profiler",
"@hwy//:thread_pool",
"//third_party/intel_dnnl:dnnl",
"//third_party/tbb",
],
)

2 changes: 1 addition & 1 deletion gemma/gemma.h
@@ -121,7 +121,7 @@ struct RuntimeConfig {
const ImageTokens *image_tokens = nullptr;

// Whether to use thread spinning to reduce barrier synchronization latency.
bool use_spinning = true;
bool use_spinning = false;

// End-of-sequence token.
int eos_id = EOS_ID;
163 changes: 158 additions & 5 deletions ops/matmul-inl.h
@@ -31,11 +31,20 @@
// After highway.h
#include "compression/compress-inl.h"
#include "hwy/contrib/math/math-inl.h"
#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl.hpp"
#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl_common.hpp"
#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl_types.h"


HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
using namespace dnnl;
using tag = memory::format_tag;
using dt = memory::data_type;
using dnnl::primitive_attr;
using dnnl::reorder;

// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
// loads, we reuse the same A row for several B columns, which are also loaded
@@ -66,7 +75,9 @@ constexpr size_t kRegRows = kRegCols;
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
// to bf16 if the native op is available. This will actually demote f32
// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
using MulT = hwy::If<HWY_NATIVE_DOT_BF16, BF16, float>;
// Force MulT to BF16 so the Highway matmul always converts inputs to bf16,
// matching the oneDNN matmul path below.
using MulT = BF16;

// Loads two vectors at a time with element type MulT from a row of transposed
// B. Called in a loop over col_ab. No bounds checking because `kRow` is
@@ -450,10 +461,10 @@ HWY_INLINE void MatMulTile(const size_t batch_size, const Mat<const MatTA>& A,
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
// Must not be called concurrently with the same `env`.
template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
const Mat<const MatTB>& B, const float scale,
const float* HWY_RESTRICT add, MatMulEnv& env,
const Mat<float>& C) {
HWY_NOINLINE void MatMul_hwy(const size_t batch_size, const Mat<const MatTA>& A,
const Mat<const MatTB>& B, const float scale,
const float* HWY_RESTRICT add, MatMulEnv& env,
const Mat<float>& C) {
// PROFILER_ZONE("Matmul");
HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
HWY_DASSERT(A.cols == B.cols);
@@ -499,6 +510,148 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
});
}

template <typename T>
static memory::data_type DnnType();

/// Instantiation for the float type. Add similar instantiations for other
/// types if needed.
template <>
memory::data_type DnnType<float>() {
return memory::data_type::f32;
}

template <>
memory::data_type DnnType<BF16>() {
return memory::data_type::bf16;
}

template <>
memory::data_type DnnType<SfpStream>() {
fprintf(stderr, "DnnType SfpStream is not supported\n");
return memory::data_type::bf16;
}
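
// Illustrative only (not part of this change): other element types can be
// mapped the same way, e.g. IEEE half precision via oneDNN's f16 data type:
//   template <>
//   memory::data_type DnnType<hwy::float16_t>() {
//     return memory::data_type::f16;
//   }
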
template <typename MatT>
memory convert_to_bf16(dnnl::engine engine, dnnl::stream engine_stream,
const Mat<const MatT>& source, memory::dims dims,
tag format_tag) {
// Wrap the input matrix in dnnl memory (no copy; only the handle is set).
dnnl::memory::desc src_md(dims, DnnType<MatT>(), format_tag);
auto source_mem = memory(src_md, engine);
source_mem.set_data_handle(const_cast<MatT*>(source.ptr + source.ofs));
if (std::is_same<MatT, BF16>::value) {
return source_mem;
}
// When the input is float, convert it to BF16.
if (std::is_same<MatT, float>::value) {
auto dst_md = memory::desc(source_mem.get_desc().get_dims(),
dnnl::memory::data_type::bf16, format_tag);
dnnl::memory dst_mem(dst_md, engine);
auto reorder_pd =
reorder::primitive_desc(engine, source_mem.get_desc(), engine, dst_md);
auto reorder_prim = reorder(reorder_pd);
reorder_prim.execute(engine_stream,
{{DNNL_ARG_FROM, source_mem}, {DNNL_ARG_TO, dst_mem}});
return dst_mem;
}
fprintf(stderr, "Unsupported type\n");
return source_mem;
}

template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul_dnnl(const size_t batch_size,
const Mat<const MatTA>& A,
const Mat<const MatTB>& B, const float scale,
const float* HWY_RESTRICT add, MatMulEnv& env,
const Mat<float>& C) {
dnnl::engine engine = env.engine;
dnnl::stream engine_stream = env.engine_stream;

// First stage: prepare the input data.
// oneDNN requires inputs and outputs to be wrapped in dnnl::memory objects,
// which also lets the library choose efficient layouts.
// Create memory dims for inputs and outputs.
const memory::dim kRowsAC = batch_size, kColsARowsB = A.cols,
kColsBC = C.cols;
memory::dims src_dims = {kRowsAC, kColsARowsB};
memory::dims weights_dims = {kColsARowsB, kColsBC};
memory::dims bias_dims = {1, kColsBC};
memory::dims dst_dims = {kRowsAC, kColsBC};

auto src_format_tag = tag::ab;
// `B` is stored transposed, hence the `ba` format tag.
auto weights_format_tag = tag::ba;
auto dest_format_tag = tag::ab;

// Create memory descriptors for inputs and outputs.
auto src_md = memory::desc(src_dims, DnnType<MulT>(), src_format_tag);
auto weights_md =
memory::desc(weights_dims, DnnType<MulT>(), weights_format_tag);
auto scale_md = memory::desc({{1}, dt::f32, tag::x});
auto dst_md = memory::desc(dst_dims, dt::f32, dest_format_tag);

auto src_mem =
convert_to_bf16(engine, engine_stream, A, src_dims, src_format_tag);
auto weights_mem = convert_to_bf16(engine, engine_stream, B, weights_dims,
weights_format_tag);
auto dst_mem = memory(dst_md, engine);

// Second stage: Create the matmul primitive/operation.
// Define matmul Primitive arguments.
std::unordered_map<int, memory> matmul_args;
matmul_args.insert({DNNL_ARG_SRC, src_mem});
matmul_args.insert({DNNL_ARG_WEIGHTS, weights_mem});
matmul_args.insert({DNNL_ARG_DST, dst_mem});

// Attach the scaling factor to the weights so that it is applied to the
// multiplication result before the bias is added.
auto scale_mem = memory(scale_md, engine, const_cast<float*>(&scale));
matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale_mem});
dnnl::primitive_attr attr;
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);

// Create primitive descriptor.
matmul::primitive_desc matmul_pd;

// When a bias is present, add it to matmul_args and the primitive descriptor.
if (kAdd) {
auto bias_md = memory::desc(bias_dims, dt::f32, tag::ab);
auto bias_mem = memory(bias_md, engine);
bias_mem.set_data_handle(const_cast<float*>(add));
matmul_args.insert({DNNL_ARG_BIAS, bias_mem});
matmul_pd = matmul::primitive_desc(engine, src_md, weights_md, bias_md,
dst_md, attr);
} else {
matmul_pd =
matmul::primitive_desc(engine, src_md, weights_md, dst_md, attr);
}

// Third stage: Execute the matmul primitive/operation.
auto matmul_prim = matmul(matmul_pd);
matmul_prim.execute(engine_stream, matmul_args);
engine_stream.wait();

// Copy the output from dnnl memory into the output matrix C.
// Zero the leading lanes of each destination row first to cover padding when
// C.stride is larger than C.cols.
auto c_mem_ptr = static_cast<float*>(dst_mem.get_data_handle());
const hn::ScalableTag<float> df;
for (size_t row = 0; row < batch_size; ++row) {
hn::StoreU(hn::Zero(df), df, C.ptr + row * C.stride);
std::copy(c_mem_ptr + row * C.cols, c_mem_ptr + (row + 1) * C.cols,
C.ptr + row * C.stride);
}
}

template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
const Mat<const MatTB>& B, const float scale,
const float* HWY_RESTRICT add, MatMulEnv& env,
const Mat<float>& C) {
MatMul_dnnl<kAdd>(batch_size, A, B, scale, add, env, C);

// To benchmark the Highway matmul instead, disable the oneDNN call above and
// enable this one.
// MatMul_hwy<kAdd>(batch_size, A, B, scale, add, env, C);
}
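
// Example call (illustrative; the pointer and size names are placeholders).
// Computes C[r,c] = scale * Dot(A.Row(r), B.Col(c)) (+ add[c] when kAdd), with
// B supplied in transposed storage:
//   MatMul</*kAdd=*/true>(batch_size, ConstMat(a_ptr, cols_ab),
//                         ConstMat(b_trans_ptr, cols_ab), scale, add_ptr, env,
//                         MutableMat(c_ptr, cols_c));
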
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
17 changes: 16 additions & 1 deletion ops/matmul.h
@@ -18,11 +18,16 @@

#include <stddef.h>

#include "tbb/global_control.h"
#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h" // PerClusterPools
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/per_target.h"
#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl.hpp"

using namespace dnnl;
using namespace tbb;

namespace gcpp {

@@ -81,8 +86,18 @@ class MatMulEnv {
const size_t num_lp = pools.NumLP();
const size_t NF = hwy::VectorBytes() / sizeof(float);
buf_ = RowVectorBatch<float>(num_lp, 16 * NF);
setenv("ONEDNN_MAX_CPU_ISA", "AVX512_CORE_AMX", 1);
// Uncomment to enable oneDNN verbose logging when debugging.
// setenv("DNNL_VERBOSE", "2", 1);
tbb::global_control global_limit(
tbb::global_control::max_allowed_parallelism, 128);
// Create execution dnnl::engine.
engine = dnnl::engine(dnnl::engine::kind::cpu, 0);
// Create dnnl::stream.
engine_stream = dnnl::stream(engine);
}

dnnl::stream engine_stream;
dnnl::engine engine;
float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); }
PerClusterPools& Pools() const { return *pools_; }
hwy::ThreadPool& Pool() const { return pools_->Inner(0); }
34 changes: 15 additions & 19 deletions ops/matmul_test.cc
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "compression/shared.h"
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32.
#define HWY_DISABLED_TARGETS HWY_SCALAR
@@ -145,7 +146,7 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
// Dot(float,BF16) rounds both to BF16.
using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
using RefType = BF16;
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
const double tolerance = 200.0 * norm * epsilon;

@@ -233,8 +234,13 @@ void TestMatMul(MatMulEnv& env) {
std::unique_ptr<CompressedArray<float, kRowsAC * kColsBC>> c_slow =
GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
const double start_slow = hwy::platform::Now();
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
kAdd ? add->data() : nullptr, env, c_slow->data());
// Compare the dnnl matmul results with the hwy matmul results.
MatMul_hwy<kAdd>(kRowsAC, ConstMat(a->data(), kColsARowsB),
ConstMat(b_trans->data(), kColsARowsB), scale,
kAdd ? add->data_scale1() : nullptr, env,
MutableMat(c_slow->data(), kColsBC));
// MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
// kAdd ? add->data() : nullptr, c_slow->data());
if (want_bench) {
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
hwy::platform::Now() - start_slow);
@@ -258,9 +264,9 @@
}

void TestAllMatMul() {
tbb::global_control global_limit(tbb::global_control::max_allowed_parallelism, 128);
// Only run on AVX3 targets; the oneDNN path targets AVX-512/AMX-capable CPUs.
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
if (HWY_TARGET != HWY_AVX3 && HWY_TARGET != HWY_AVX3_SPR) {
return;
}

@@ -272,10 +278,10 @@
using SFP = SfpStream;

// large-scale test: batch_size=128 is better than 64 or 256 for SKX.
TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env);
TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env);
TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<128, 3072, 24576, /*kAdd=*/false, BF16>(env);
TestMatMul<1, 24576, 3072, /*kAdd=*/false, BF16>(env);
TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, BF16>(env);

// medium-sized square test - temporarily disabled for faster testing.
if constexpr (false) {
@@ -292,32 +298,22 @@
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env);
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env);
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)