Skip to content

Commit

Permalink
add bf16 cuda kernel support (microsoft#3092)
Browse files Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
dc3671 and tjruwase authored Apr 22, 2023
1 parent 2f8d384 commit 297cd9e
Show file tree
Hide file tree
Showing 11 changed files with 859 additions and 407 deletions.
10 changes: 10 additions & 0 deletions csrc/includes/conversion_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,11 @@ DS_D_INLINE __half2 to(float2 val)
{
return __float22half2_rn(val);
}
template <>
DS_D_INLINE __half2 to(float val)
{
    // Broadcast a single fp32 value into both halves of a __half2
    // (round-to-nearest-even), mirroring the float2 overload above.
    const __half2 broadcast_pair = __float2half2_rn(val);
    return broadcast_pair;
}

#ifdef BF16_AVAILABLE
// No direct conversion
Expand Down Expand Up @@ -406,6 +411,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val)
return __float22bfloat162_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat162 to(float val)
{
    // Broadcast a single fp32 value into both lanes of an __nv_bfloat162
    // (round-to-nearest-even), matching the __half2 specialization.
    const __nv_bfloat162 broadcast_pair = __float2bfloat162_rn(val);
    return broadcast_pair;
}
template <>
DS_D_INLINE __nv_bfloat162 to(__half2 val)
{
return to<__nv_bfloat162>(to<float2>(val));
Expand Down
140 changes: 83 additions & 57 deletions csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

// DeepSpeed Team

#include "conversion_utils.h"
#include "inference_cuda_layers.h"

#ifndef __HIP_PLATFORM_HCC__
Expand All @@ -12,8 +13,9 @@
namespace cg = cooperative_groups;
namespace cg = cooperative_groups;

__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
template <typename T>
__global__ void apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
Expand All @@ -40,8 +42,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[k_offset + lane];
float q = conversion::to<float>(mixed_query[offset + lane]);
float k = conversion::to<float>(key_layer[k_offset + lane]);
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
Expand All @@ -50,59 +52,14 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

mixed_query[offset + lane] = q;
key_layer[k_offset + lane] = k;
mixed_query[offset + lane] = conversion::to<T>(q);
key_layer[k_offset + lane] = conversion::to<T>(k);

lane += WARP_SIZE;
}
}
}

// Applies interleaved rotary position embedding (RoPE) in place to the query
// and key activations, fp16 variant.
//
// Launch layout (from the index math below): each 32-lane warp processes one
// attention head; a block carries MAX_WARP_NUM warps, so
// head_id = blockIdx.x * MAX_WARP_NUM + warp_id. Lanes stride across the
// rotary dimension in steps of WARP_SIZE.
//
// The (even, odd) element pairing is done with shfl_xor(…, 1): each lane
// swaps its rotated value with its adjacent lane, and rotary_sign flips the
// sign on odd lanes, yielding (q_even, q_odd) -> (q_even*cos - q_odd*sin,
// q_odd*cos + q_even*sin). This requires rotary_dim-aligned warp coverage of
// the pair layout — the statement order around the shuffles must not change.
//
// NOTE(review): `1.0 / powf(10000.0, …)` and the `-1.0 : 1.0` literals are
// doubles mixed into float math; the templated replacement kernel keeps the
// same expressions, so results stay bit-identical across the refactor.
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

int id = threadIdx.x;
int gid = id >> 5;    // warp index within the block
int lane = id & 0x1f; // lane index within the warp

unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size; // element offset of this head in mixed_query

// Absolute sequence position (past tokens included via seq_offset) used as
// the rotation angle multiplier.
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
// key_layer rows are strided by max_out_tokens (KV cache capacity), not seq_len.
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

if (head_id < total_count) {
while (lane < rotary_dim) {
// theta = seq_id / 10000^(2*(lane/2)/rotary_dim); pairs of lanes share a frequency.
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
// Odd lanes contribute the negated partner term of the rotation.
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
// Exchange with the adjacent lane to fetch the paired element.
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

// Write back in fp16; elements beyond rotary_dim are left untouched.
mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;

lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
Expand Down Expand Up @@ -148,8 +105,10 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,

template <typename T>
__global__ void apply_rotary_pos_emb1(T* mixed_query,
T* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
Expand Down Expand Up @@ -185,8 +144,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
float q = conversion::to<float>(mixed_query[offset + lane]);
float k = conversion::to<float>(key_layer[k_offset + lane]);
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
Expand All @@ -197,8 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);

mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;
mixed_query[offset + lane] = conversion::to<T>(q);
key_layer[k_offset + lane] = conversion::to<T>(k);

lane += WARP_SIZE;
}
Expand Down Expand Up @@ -256,6 +215,20 @@ template void launch_apply_rotary_pos_emb<float>(float*,
bool,
cudaStream_t,
int);
// Explicit instantiation of the host launch wrapper for bf16; guarded so
// builds against toolkits/archs without __nv_bfloat16 still compile.
#ifdef BF16_AVAILABLE
template void launch_apply_rotary_pos_emb<__nv_bfloat16>(__nv_bfloat16*,
__nv_bfloat16*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
cudaStream_t,
int);
#endif
// Explicit instantiation of the host launch wrapper for fp16.
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
cudaStream_t,
int);

// Explicit kernel instantiations so the templated __global__ definitions are
// emitted for each supported element type (fp32, bf16 when available, fp16).
template __global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);

#ifdef BF16_AVAILABLE
template __global__ void apply_rotary_pos_emb(__nv_bfloat16* mixed_query,
__nv_bfloat16* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#endif

template __global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);

// Same set of instantiations for the apply_rotary_pos_emb1 variant.
// NOTE(review): no fp32 instantiation of apply_rotary_pos_emb1 appears here —
// presumably the float overload is still a non-template definition above;
// confirm against the full file.
#ifdef BF16_AVAILABLE
template __global__ void apply_rotary_pos_emb1(__nv_bfloat16* mixed_query,
__nv_bfloat16* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#endif

template __global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
Expand Down
Loading

0 comments on commit 297cd9e

Please sign in to comment.