Skip to content

Commit

Permalink
[ROCm] Bugfix of BFloat16-float conversion and Add FastGelu Kernel for AMD (microsoft#10557)
Browse files Browse the repository at this point in the history

* bf16 bugfix on amd

* enable fastgelu ut on amd
centwang authored Feb 16, 2022
1 parent f22cd3a commit ceb1e2b
Showing 4 changed files with 86 additions and 37 deletions.
22 changes: 22 additions & 0 deletions include/onnxruntime/core/framework/float16.h
Original file line number Diff line number Diff line change
@@ -50,6 +50,20 @@ struct BFloat16 {
inline ORT_HOST_DEVICE BFloat16(float v) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
val = __bfloat16_as_ushort(__float2bfloat16(v));
#elif defined(USE_ROCM)
// We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
if (v != v) { // isnan
val = UINT16_C(0x7FC0);
} else {
union {
uint32_t U32;
float F32;
};

F32 = v;
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
#else
ORT_IF_CONSTEXPR(endian::native == endian::little) {
std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
@@ -63,6 +77,14 @@ struct BFloat16 {
inline ORT_HOST_DEVICE float ToFloat() const {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
#elif defined(USE_ROCM)
// We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
float result = 0;
uint32_t tmp = val;
tmp <<= 16;
float* tempRes = reinterpret_cast<float*>(&tmp);
result = *tempRes;
return result;
#else
float result;
char* const first = reinterpret_cast<char*>(&result);
10 changes: 10 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu
Original file line number Diff line number Diff line change
@@ -113,6 +113,16 @@ bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int i
return HIP_CALL(hipPeekAtLastError());
}

template <>
bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int input_length, int bias_length,
                          const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
  // BFloat16 specialization of the FastGelu launcher: one thread per element.
  // use_half2 is ignored — there is no vectorized (half2-style) path for BFloat16 here.
  constexpr int kThreadsPerBlock = 256;
  // Ceiling division so a partially-filled final block still covers the tail elements.
  const int num_blocks = (input_length + kThreadsPerBlock - 1) / kThreadsPerBlock;
  // A, B, C are the file-scope FastGelu tanh-approximation constants — presumably defined
  // earlier in this translation unit; TODO(review) confirm against the full fast_gelu_impl.cu.
  hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<BFloat16, kThreadsPerBlock>),
                     dim3(num_blocks), dim3(kThreadsPerBlock), 0, stream,
                     A, B, C, input_length, bias_length, input, bias, output);
  // Surface any launch-time error (invalid config, etc.) without resetting the error state.
  return HIP_CALL(hipPeekAtLastError());
}

} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);

// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization);
@@ -166,7 +166,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>

// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
87 changes: 52 additions & 35 deletions onnxruntime/test/contrib_ops/fastgelu_op_test.cc
Original file line number Diff line number Diff line change
@@ -40,41 +40,51 @@ const std::vector<float> GetExpectedResult(const std::vector<float>& input_data,
return ComputeGelu(add_bias_data);
}

static void RunFastGeluTest(
const std::vector<float>& input_data,
const std::vector<float>& bias_data,
const std::vector<float>& output_data,
const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& bias_dims,
const std::vector<int64_t>& output_dims,
bool has_bias = true,
bool use_float16 = false) {
#if defined(USE_CUDA) || defined(USE_ROCM)
static void RunFastGeluGpuTest(const std::vector<float>& input_data, const std::vector<float>& bias_data,
const std::vector<float>& output_data, const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& bias_dims, const std::vector<int64_t>& output_dims,
bool has_bias = true, bool use_float16 = false) {
#ifdef USE_CUDA
// Test CUDA operator.
int min_cuda_architecture = use_float16 ? 530 : 0;
if (HasCudaEnvironment(min_cuda_architecture)) {
OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
return;
}
#endif
OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);

if (use_float16) {
tester.AddInput<MLFloat16>("X", input_dims, ToFloat16(input_data));
if (has_bias) {
tester.AddInput<MLFloat16>("bias", bias_dims, ToFloat16(bias_data));
}
tester.AddOutput<MLFloat16>("Y", output_dims, ToFloat16(output_data));
} else {
tester.AddInput<float>("X", input_dims, input_data);
if (has_bias) {
tester.AddInput<float>("bias", bias_dims, bias_data);
}
tester.AddOutput<float>("Y", output_dims, output_data);
if (use_float16) {
tester.AddInput<MLFloat16>("X", input_dims, ToFloat16(input_data));
if (has_bias) {
tester.AddInput<MLFloat16>("bias", bias_dims, ToFloat16(bias_data));
}

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
tester.AddOutput<MLFloat16>("Y", output_dims, ToFloat16(output_data));
} else {
tester.AddInput<float>("X", input_dims, input_data);
if (has_bias) {
tester.AddInput<float>("bias", bias_dims, bias_data);
}
tester.AddOutput<float>("Y", output_dims, output_data);
}

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif

static void RunFastGeluCpuTest(const std::vector<float>& input_data, const std::vector<float>& bias_data,
const std::vector<float>& output_data, const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& bias_dims, const std::vector<int64_t>& output_dims,
bool has_bias = true) {
// Test CPU operator: only float32 is implemented for FastGelu CPU.
if (nullptr != DefaultCpuExecutionProvider().get() && !use_float16) {
if (nullptr != DefaultCpuExecutionProvider().get()) {
OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);

tester.AddInput<float>("X", input_dims, input_data);
@@ -107,7 +117,10 @@ static void RunFastGeluTest(
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> bias_dims = {hidden_size};
std::vector<int64_t> output_dims = input_dims;
RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias);
#if defined(USE_CUDA) || defined(USE_ROCM)
RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias);
#endif
RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias);
}

TEST(FastGeluTest, FastGeluWithNullInput) {
@@ -152,6 +165,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) {
RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size);
}

// CUDA and ROCm only for Float16 and BFloat16 type.
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FastGeluTest, FastGeluWithBiasFloat16) {
int batch_size = 1;
int sequence_length = 2;
@@ -172,7 +187,7 @@ TEST(FastGeluTest, FastGeluWithBiasFloat16) {
std::vector<int64_t> bias_dims = {hidden_size};
std::vector<int64_t> output_dims = input_dims;

RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, true, true);
RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, true, true);
}

TEST(FastGeluTest, FastGeluWithoutBiasFloat16) {
@@ -194,17 +209,17 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16) {
std::vector<int64_t> bias_dims = {};
std::vector<int64_t> output_dims = input_dims;

RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
}

// CUDA only, ROCM has not been supported yet
#ifdef USE_CUDA
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
return;
}
#endif
OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);

int batch_size = 1;
@@ -235,12 +250,14 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
tester.AddOutput<BFloat16>("Y", output_dims, f_Y);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif



} // namespace test
} // namespace onnxruntime

0 comments on commit ceb1e2b

Please sign in to comment.