Zhanyao/attention (microsoft#10545)
* Enable Attention op for ROCM EP.

As a note, potential hipify improvements: (1) handle math
constants (attention_softmax.h), (2) correctly generate transpose
options for the GEMM helpers, and (3) consider a counterpart/dummy API for
CublasMathModeSetter (attention_impl.cu); a rough sketch of such a
counterpart appears after the commit summary below. After these
improvements, we will no longer need to maintain manual copies of the
files mentioned above.

* Clean up debugging code.
zhangyaobit authored Feb 17, 2022
1 parent 8d06e5a commit fd16085
Showing 6 changed files with 773 additions and 7 deletions.
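
As a rough illustration of item (3) in the commit note: hipify could be pointed at a no-op ROCm counterpart of CublasMathModeSetter so the CUDA attention sources translate without manual copies. The snippet below is only a sketch under that assumption; the name RocblasMathModeSetter, the header paths, and the constructor signature are hypothetical and are not part of this commit.

#include <hip/hip_runtime.h>
#include <rocblas.h>

namespace onnxruntime {
namespace rocm {

// Hypothetical no-op stand-in for CUDA's CublasMathModeSetter. rocBLAS has no
// direct equivalent of cublasSetMathMode, so the constructor simply ignores its
// arguments; keeping the same shape lets hipified call sites compile unchanged.
class RocblasMathModeSetter {
 public:
  RocblasMathModeSetter(const hipDeviceProp_t& /*prop*/,
                        rocblas_handle /*handle*/,
                        int /*math_mode*/) {
    // Intentionally empty: math-mode selection is a cuBLAS-specific concept.
  }
};

}  // namespace rocm
}  // namespace onnxruntime
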
126 changes: 126 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/attention.cc
@@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/rocm/bert/attention.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"

using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace rocm {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Attention, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Attention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
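// The registrations above bind Attention in the com.microsoft (kMSDomain) domain,
// opset version 1, to this ROCm kernel for float and half (MLFloat16) inputs.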

template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : RocmKernel(info), AttentionBase(info) {}

template <typename T>
Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights = context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(4);
const Tensor* extra_add_qk = context->Input<Tensor>(5);

auto& device_prop = GetDeviceProp();
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past, extra_add_qk, device_prop.maxThreadsPerBlock));

// input shape (batch_size, sequence_length, input_hidden_size)
const auto& shape = input->Shape();
int batch_size = static_cast<int>(shape[0]);
int sequence_length = static_cast<int>(shape[1]);
int input_hidden_size = static_cast<int>(shape[2]);

// bias shape (3 * hidden_size)
const auto& bias_shape = bias->Shape();
int hidden_size = static_cast<int>(bias_shape[0]) / 3;

int head_size = hidden_size / num_heads_;

TensorShapeVector output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
Tensor* output = context->Output(0, output_shape);

int past_sequence_length = 0;
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
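// When past state is provided, 'present' concatenates past and current K/V with shape
// (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size).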

rocblas_handle rocblas = RocblasHandle();
constexpr size_t element_size = sizeof(T);

// Use GEMM for the fully connected QKV projection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int k = input_hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);

typedef typename ToHipType<T>::MappedType HipT;
HipT one = ToHipType<T>::FromFloat(1.0f);
HipT zero = ToHipType<T>::FromFloat(0.0f);

// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
// TODO: use a custom expand kernel to improve performance.
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->template Data<T>()), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_buffer.get()), n));

// GEMM; note that rocBLAS assumes column-major storage, so result(N, M) = 1 * weights x input + 1 * B.
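// In row-major terms this computes gemm_buffer(M, N) = input(M, K) x weights(K, N) + bias,
// since a column-major (N, M) result viewed as row-major memory is its transpose;
// here M = batch_size * sequence_length, N = 3 * hidden_size, K = input_hidden_size.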
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(weights->template Data<T>()), n,
reinterpret_cast<const HipT*>(input->template Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_buffer.get()), n));
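// gemm_buffer now holds the packed QKV projection with logical shape
// (batch_size, sequence_length, 3, num_heads, head_size), i.e. BxSx3xNxH,
// which is the layout expected by LaunchTransQkv in the attention kernel.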

size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, past_sequence_length);
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
if (!LaunchAttentionKernel(
device_prop,
Stream(),
reinterpret_cast<const HipT*>(gemm_buffer.get()),
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
output->template MutableData<T>(),
batch_size,
sequence_length,
num_heads_,
head_size,
temp_buffer.get(),
rocblas,
element_size,
is_unidirectional_,
past_sequence_length,
nullptr == past ? nullptr : past->template Data<T>(),
nullptr == extra_add_qk ? nullptr : extra_add_qk->template Data<T>(),
nullptr == present ? nullptr : present->template MutableData<T>())) {
// Get last error to reset it to hipSuccess.
HIP_CALL(hipGetLastError());
return Status(common::ONNXRUNTIME, common::FAIL);
}

return Status::OK();
}

} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
213 changes: 213 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
@@ -0,0 +1,213 @@
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Modifications: scaling is moved from the masked softmax to the GEMM before it.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <hip/hip_fp16.h>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/attention_softmax.h"
#include "contrib_ops/rocm/bert/transformer_common.h"

using namespace onnxruntime::rocm;
using namespace hipcub;

#define CHECK_ROCM(expr) \
if (!HIP_CALL(expr)) { \
return false; \
}

namespace onnxruntime {
namespace contrib {
namespace rocm {

static size_t AlignTo(size_t a, size_t b) {
return CeilDiv(a, b) * b;
}

size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) {
const size_t len = batch_size * num_heads * sequence_length * all_sequence_length;
const size_t bytes = len * element_size;

const size_t alignment = 256;
const size_t bytesAligned = AlignTo(bytes, alignment);
return bytesAligned;
}

size_t GetAttentionWorkspaceSize(
size_t element_size,
int batch_size,
int num_heads,
int head_size,
int sequence_length,
int past_sequence_length) {
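// The workspace holds two aligned score buffers of BxNxSxS* elements each
// (scratch1 and scratch2 in QkvToContext) followed by the transposed QKV buffer
// (scratch3, 3xBxNxSxH elements).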
size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length);
}

template <typename T>
bool QkvToContext(
const hipDeviceProp_t& prop, rocblas_handle& rocblas, hipStream_t stream,
const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size,
const T* input, T* output, T* workspace,
const int* mask_index, gsl::span<const int64_t> mask_index_dims,
bool is_unidirectional, int past_sequence_length, const T* past, const T* extra_add_qk, T* present, bool use_persistent_softmax) {
const int all_sequence_length = past_sequence_length + sequence_length;
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
T* scratch1 = workspace;
T* scratch2 = scratch1 + (bytes / element_size);
T* scratch3 = scratch2 + (bytes / element_size);
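// Workspace layout: scratch1 holds the raw attention scores Q*K' (BxNxSxS*),
// scratch2 holds the softmax probabilities P (BxNxSxS*), and scratch3 holds the
// transposed Q, K, V (3xBxNxSxH) and, later, the per-head context P*V (BxNxSxH).
// Overall the kernel computes, per batch and head: context = softmax(Q*K' / sqrt(H) + mask) * V.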

const int max_threads_per_block = prop.maxThreadsPerBlock;

// input should be BxSx3xNxH => scratch3: 3xBxNxSxH
if (!LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, input, scratch3)) {
return false;
}

// now scratch3 has Q, K, V: each has size BxNxSxH
const int batches = batch_size * num_heads;
const int size_per_batch = sequence_length * head_size;
const int total_size = batches * size_per_batch;

const T* q = scratch3;
const T* k = q + total_size;
const T* v = k + total_size;

rocblas_set_stream(rocblas, stream);

// Concat past (2xBxNxS'xH) to present (2xBxNxS*xH):
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH)
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH)
const int present_size_per_batch = all_sequence_length * head_size;
if (nullptr != present) {
if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, past, k, present)) {
return false;
}

// update pointers to present_k and present_v.
k = present;
v = present + batches * present_size_per_batch;
}

// Raw attention mask could be 2D (BxS), 3D (BxSxS*), or 4D (Bx1xMxM), where M is the max sequence length.
bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2);

// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
const int temp_matrix_size = sequence_length * all_sequence_length;

typedef typename ToHipType<T>::MappedType HipT;

//float one = 1.0f;
//float zero = 0.f;
const HipT one = ToHipType<T>::FromFloat(1.0f);
const HipT zero = ToHipType<T>::FromFloat(0.f);

// For raw attention mask, the scale of 1/sqrt(H) is moved to the softmax computation.
//float temp_alpha = use_raw_attention_mask ? one : rsqrt_head_size;
const HipT alpha = use_raw_attention_mask ? one : ToHipType<T>::FromFloat(rsqrt_head_size);
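// The strided-batched call below produces, per batch, a column-major
// (all_sequence_length, sequence_length) matrix whose row-major view is Q*K' (BxNxSxS*), scaled by alpha.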

if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper(
rocblas, rocblas_operation_transpose, rocblas_operation_none, all_sequence_length, sequence_length, head_size, &alpha, k, head_size, present_size_per_batch,
q, head_size, size_per_batch, &zero, scratch1, all_sequence_length, temp_matrix_size, batches))) {
return false;
}

// apply softmax and store result P to scratch2: BxNxSxS*
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
const int mask_dimension = static_cast<int>(mask_index_dims.size());
const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims.at(3) : 0;

T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected.
if (!ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, extra_add_qk, scratch1, scratch2,
is_unidirectional, rsqrt_head_size, mask_dimension, static_cast<int>(max_sequence_length),
use_persistent_softmax, persistent_softmax_workspace)) {
return false;
}
} else if (nullptr != mask_index) { // 1d mask index
ORT_ENFORCE(mask_index_dims.size() == 1);
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the latter one has start positions.
const int* mask_start = (mask_index_dims.at(0) > batch_size) ? mask_index + batch_size : nullptr;
if (!ComputeSoftmaxWithMask1D<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) {
return false;
}
} else { // no mask
if (!ComputeSoftmax<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, extra_add_qk, scratch1, scratch2, is_unidirectional)) {
return false;
}
}

// compute P*V (as V*P), and store in scratch3: BxNxSxH
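// As above, the column-major (head_size, sequence_length) result per batch is,
// viewed row-major, the SxH context matrix P*V.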
if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, head_size, sequence_length, all_sequence_length, &one, v, head_size, present_size_per_batch,
scratch2, all_sequence_length, temp_matrix_size, &zero, scratch3, head_size, size_per_batch, batches))) {
return false;
}

// scratch3 is BxNxSxH, transpose to output BxSxNxH
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, scratch3, output);
}

bool LaunchAttentionKernel(
const hipDeviceProp_t& prop,
hipStream_t stream,
const void* input,
const int* mask_index,
gsl::span<const int64_t> mask_index_dims,
void* output,
const int batch_size,
const int sequence_length,
const int num_heads,
const int head_size,
void* workspace,
rocblas_handle& rocblas,
const size_t element_size,
bool is_unidirectional,
int past_sequence_length,
const void* past,
const void* extra_add_qk,
void* present) {
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
const TransformerOptions* options = TransformerOptions::GetInstance();
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
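// element_size == 2 selects the fp16 (__half) path; otherwise the fp32 path is used.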
if (element_size == 2) {
return QkvToContext(prop, rocblas, stream,
batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const __half*>(input), reinterpret_cast<__half*>(output), reinterpret_cast<__half*>(workspace),
mask_index, mask_index_dims, is_unidirectional,
past_sequence_length, reinterpret_cast<const __half*>(past), reinterpret_cast<const __half*>(extra_add_qk),
reinterpret_cast<__half*>(present), use_persistent_softmax);
} else {
return QkvToContext(prop, rocblas, stream,
batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const float*>(input), reinterpret_cast<float*>(output), reinterpret_cast<float*>(workspace),
mask_index, mask_index_dims, is_unidirectional,
past_sequence_length, reinterpret_cast<const float*>(past), reinterpret_cast<const float*>(extra_add_qk),
reinterpret_cast<float*>(present), use_persistent_softmax);
}
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime