diff --git a/csrc/includes/conversion_utils.h b/csrc/includes/conversion_utils.h index 65fa6fd4ed3b..0bd97055c3eb 100644 --- a/csrc/includes/conversion_utils.h +++ b/csrc/includes/conversion_utils.h @@ -334,6 +334,11 @@ DS_D_INLINE __half2 to(float2 val) { return __float22half2_rn(val); } +template <> +DS_D_INLINE __half2 to(float val) +{ + return __float2half2_rn(val); +} #ifdef BF16_AVAILABLE // No direct conversion @@ -406,6 +411,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val) return __float22bfloat162_rn(val); } template <> +DS_D_INLINE __nv_bfloat162 to(float val) +{ + return __float2bfloat162_rn(val); +} +template <> DS_D_INLINE __nv_bfloat162 to(__half2 val) { return to<__nv_bfloat162>(to(val)); diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index 38b57951093d..3baa88c72f84 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -3,6 +3,7 @@ // DeepSpeed Team +#include "conversion_utils.h" #include "inference_cuda_layers.h" #ifndef __HIP_PLATFORM_HCC__ @@ -12,8 +13,9 @@ namespace cg = cooperative_groups; namespace cg = cooperative_groups; -__global__ void apply_rotary_pos_emb(float* mixed_query, - float* key_layer, +template +__global__ void apply_rotary_pos_emb(T* mixed_query, + T* key_layer, unsigned rotary_dim, unsigned seq_len, unsigned seq_offset, @@ -40,8 +42,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, while (lane < rotary_dim) { float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; - float q = mixed_query[offset + lane]; - float k = key_layer[k_offset + lane]; + float q = conversion::to(mixed_query[offset + lane]); + float k = conversion::to(key_layer[k_offset + lane]); float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -50,59 +52,14 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, q = q * cosf(inv_freq) + q_rot * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); - mixed_query[offset + lane] = q; - key_layer[k_offset + lane] = k; + mixed_query[offset + lane] = conversion::to(q); + key_layer[k_offset + lane] = conversion::to(k); lane += WARP_SIZE; } } } -__global__ void apply_rotary_pos_emb(__half* mixed_query, - __half* key_layer, - unsigned rotary_dim, - unsigned seq_len, - unsigned seq_offset, - unsigned num_heads, - unsigned head_size, - unsigned total_count, - int max_out_tokens) -{ - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int id = threadIdx.x; - int gid = id >> 5; - int lane = id & 0x1f; - - unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; - unsigned offset = head_id * head_size; - - unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; - unsigned seq_index = head_id % seq_len; - unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; - - if (head_id < total_count) { - while (lane < rotary_dim) { - float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; - inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; - float q = (float)mixed_query[offset + lane]; - float k = (float)key_layer[k_offset + lane]; - float rotary_sign = (lane % 2 == 1 ? 
-1.0 : 1.0); - float q_rot = (q * rotary_sign); - float k_rot = (k * rotary_sign); - q_rot = g.shfl_xor(q_rot, 1); - k_rot = g.shfl_xor(k_rot, 1); - q = q * cosf(inv_freq) + q_rot * sinf(inv_freq); - k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); - - mixed_query[offset + lane] = (__half)q; - key_layer[k_offset + lane] = (__half)k; - - lane += WARP_SIZE; - } - } -} __global__ void apply_rotary_pos_emb1(float* mixed_query, float* key_layer, unsigned rotary_dim, @@ -148,8 +105,10 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query, } } } -__global__ void apply_rotary_pos_emb1(__half* mixed_query, - __half* key_layer, + +template +__global__ void apply_rotary_pos_emb1(T* mixed_query, + T* key_layer, unsigned rotary_dim, unsigned seq_len, unsigned seq_offset, @@ -185,8 +144,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, while (lane < rotary_dim) { float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; - float q = (float)mixed_query[offset + lane]; - float k = (float)key_layer[k_offset + lane]; + float q = conversion::to(mixed_query[offset + lane]); + float k = conversion::to(key_layer[k_offset + lane]); float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -197,8 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq); - mixed_query[offset + lane] = (__half)q; - key_layer[k_offset + lane] = (__half)k; + mixed_query[offset + lane] = conversion::to(q); + key_layer[k_offset + lane] = conversion::to(k); lane += WARP_SIZE; } @@ -256,6 +215,20 @@ template void launch_apply_rotary_pos_emb(float*, bool, cudaStream_t, int); +#ifdef BF16_AVAILABLE +template void launch_apply_rotary_pos_emb<__nv_bfloat16>(__nv_bfloat16*, + __nv_bfloat16*, + unsigned, + unsigned, + unsigned, + unsigned, + unsigned, + unsigned, + bool, + bool, + cudaStream_t, + int); +#endif template void launch_apply_rotary_pos_emb<__half>(__half*, __half*, unsigned, @@ -269,6 +242,59 @@ template void launch_apply_rotary_pos_emb<__half>(__half*, cudaStream_t, int); +template __global__ void apply_rotary_pos_emb(float* mixed_query, + float* key_layer, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + int max_out_tokens); + +#ifdef BF16_AVAILABLE +template __global__ void apply_rotary_pos_emb(__nv_bfloat16* mixed_query, + __nv_bfloat16* key_layer, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + int max_out_tokens); +#endif + +template __global__ void apply_rotary_pos_emb(__half* mixed_query, + __half* key_layer, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + int max_out_tokens); + +#ifdef BF16_AVAILABLE +template __global__ void apply_rotary_pos_emb1(__nv_bfloat16* mixed_query, + __nv_bfloat16* key_layer, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + int max_out_tokens); +#endif + +template __global__ void apply_rotary_pos_emb1(__half* mixed_query, + __half* key_layer, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + int 
max_out_tokens); /* __global__ void apply_rotary_pos_emb(float* mixed_query, float* key_layer, diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index 539d4de60929..fe7b665f0efb 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -3,6 +3,7 @@ // DeepSpeed Team +#include "conversion_utils.h" #include "inference_cuda_layers.h" #define MAX_QUANTIZE_GROUPING 1024 @@ -10,7 +11,8 @@ #define loop_unroll 1 #define loop_unroll_bits 1 -__global__ void dequantize_kernel(float* output, +template +__global__ void dequantize_kernel(T* output, const int8_t* input, const float* qscale, int output_size, @@ -38,40 +40,7 @@ __global__ void dequantize_kernel(float* output, float scale_data = qscale[scale_index]; - output[q_index] = (scale_data * (float)q); - tid += blockDim.x; - } -} - -__global__ void dequantize_kernel(__half* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count) -{ - unsigned merge_hidden = hidden_dim >> merge_count; - unsigned quantization_stride = (merge_hidden * output_size) / groups; - - unsigned bid = blockIdx.x; - unsigned tid = threadIdx.x; - - while (tid < output_size) { - unsigned w_index = bid / merge_hidden; - unsigned q_index = tid + bid * output_size; - - auto q = input[q_index]; - - unsigned merge_hidden_total = w_index * merge_hidden; - unsigned scale_index = - ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) - << merge_count) + - w_index; - - float scale_data = qscale[scale_index]; - - output[q_index] = __float2half(scale_data * (float)q); + output[q_index] = conversion::to(scale_data * (float)q); tid += blockDim.x; } } @@ -102,6 +71,18 @@ template void launch_dequantize(float*, unsigned, unsigned, cudaStream_t); + +#ifdef BF16_AVAILABLE +template void launch_dequantize<__nv_bfloat16>(__nv_bfloat16*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + unsigned, + cudaStream_t); +#endif + template void launch_dequantize<__half>(__half*, const int8_t*, const float*, @@ -120,7 +101,8 @@ __global__ void dequantize_kernel(float* output, { } -__global__ void dequantize_kernel(__half* output, +template +__global__ void dequantize_kernel(T* output, const int8_t* input, const float* qscale, unsigned hidden_dim, @@ -144,12 +126,12 @@ __global__ void dequantize_kernel(__half* output, int8_t* q_int8 = (int8_t*)&q; float2 q_f; - __half* q_h = (__half*)&q_f; + T* q_h = (T*)&q_f; - q_h[0] = __float2half(local_scale * (float)q_int8[0]); - q_h[1] = __float2half(local_scale * (float)q_int8[1]); - q_h[2] = __float2half(local_scale * (float)q_int8[2]); - q_h[3] = __float2half(local_scale * (float)q_int8[3]); + q_h[0] = conversion::to(local_scale * (float)q_int8[0]); + q_h[1] = conversion::to(local_scale * (float)q_int8[1]); + q_h[2] = conversion::to(local_scale * (float)q_int8[2]); + q_h[3] = conversion::to(local_scale * (float)q_int8[3]); output_cast[tid] = q_f; tid += blockDim.x; } @@ -186,6 +168,17 @@ template void launch_dequantize(float*, unsigned, unsigned, cudaStream_t); + +#ifdef BF16_AVAILABLE +template void launch_dequantize<__nv_bfloat16>(__nv_bfloat16*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + cudaStream_t); +#endif + template void launch_dequantize<__half>(__half*, const int8_t*, const float*, @@ -193,3 +186,45 @@ template void launch_dequantize<__half>(__half*, unsigned, unsigned, cudaStream_t); + 
+template __global__ void dequantize_kernel(float* output, + const int8_t* input, + const float* qscale, + int output_size, + int hidden_dim, + int groups, + int merge_count); + +#ifdef BF16_AVAILABLE +template __global__ void dequantize_kernel(__nv_bfloat16* output, + const int8_t* input, + const float* qscale, + int output_size, + int hidden_dim, + int groups, + int merge_count); +#endif + +template __global__ void dequantize_kernel(__half* output, + const int8_t* input, + const float* qscale, + int output_size, + int hidden_dim, + int groups, + int merge_count); + +#ifdef BF16_AVAILABLE +template __global__ void dequantize_kernel(__nv_bfloat16* output, + const int8_t* input, + const float* qscale, + unsigned hidden_dim, + unsigned merge_hidden, + int cnt); +#endif + +template __global__ void dequantize_kernel(__half* output, + const int8_t* input, + const float* qscale, + unsigned hidden_dim, + unsigned merge_hidden, + int cnt); diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index d62b135f509b..26b6de0cdc53 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -11,6 +11,11 @@ namespace cg = cooperative_groups; #define MAX_CAP 4 #define MAX_SEQ 2048 +// only used to avoid compilation error due to lack of definition. +#ifndef BF16_AVAILABLE +using __nv_bfloat162 = __half2; +#endif + inline __device__ float gelu(const float x) { const float sqrt_param = 0.79788456080286535587989211986876f; @@ -66,6 +71,13 @@ void launch_bias_gelu(T* input, } template void launch_bias_gelu(float*, const float*, int, int, cudaStream_t); +#ifdef BF16_AVAILABLE +template void launch_bias_gelu<__nv_bfloat16>(__nv_bfloat16*, + const __nv_bfloat16*, + int, + int, + cudaStream_t); +#endif template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t); /* @@ -116,6 +128,13 @@ void launch_bias_add(T* input, } template void launch_bias_add(float*, const float*, int, int, cudaStream_t); +#ifdef BF16_AVAILABLE +template void launch_bias_add<__nv_bfloat16>(__nv_bfloat16*, + const __nv_bfloat16*, + int, + int, + cudaStream_t); +#endif template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t); __global__ void fused_bias_residual(float* residual, @@ -163,16 +182,19 @@ __global__ void fused_bias_residual(float* residual, } } -__global__ void fused_bias_residual(__half* residual, - const __half* hidden_state, - const __half* attn, - const __half* bias, - const __half* attn_bias, +template +__global__ void fused_bias_residual(T* residual, + const T* hidden_state, + const T* attn, + const T* bias, + const T* attn_bias, const int total_count, const int intermediate_size, const float mp_scale, const bool preln) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float2* res_fl2_ptr = reinterpret_cast(residual); const float2* hs_fl2_ptr = reinterpret_cast(hidden_state); const float2* attn_fl2_ptr = reinterpret_cast(attn); @@ -187,26 +209,26 @@ __global__ void fused_bias_residual(__half* residual, const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; - __half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2); - const __half2* hs_half2 = reinterpret_cast(&hs_fl2); - const __half2* attn_half2 = reinterpret_cast(&attn_fl2); - const __half2* bias_half2 = reinterpret_cast(&bias_fl2); - const __half2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); + T2* res_half2 = 
reinterpret_cast(&res_fl2); + const T2* hs_half2 = reinterpret_cast(&hs_fl2); + const T2* attn_half2 = reinterpret_cast(&attn_fl2); + const T2* bias_half2 = reinterpret_cast(&bias_fl2); + const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); - float2 res_low = __half22float2(res_half2[0]); - float2 res_high = __half22float2(res_half2[1]); + float2 res_low = conversion::to(res_half2[0]); + float2 res_high = conversion::to(res_half2[1]); - const float2 hs_low = __half22float2(hs_half2[0]); - const float2 hs_high = __half22float2(hs_half2[1]); + const float2 hs_low = conversion::to(hs_half2[0]); + const float2 hs_high = conversion::to(hs_half2[1]); - const float2 attn_low = __half22float2(attn_half2[0]); - const float2 attn_high = __half22float2(attn_half2[1]); + const float2 attn_low = conversion::to(attn_half2[0]); + const float2 attn_high = conversion::to(attn_half2[1]); - const float2 bias_low = __half22float2(bias_half2[0]); - const float2 bias_high = __half22float2(bias_half2[1]); + const float2 bias_low = conversion::to(bias_half2[0]); + const float2 bias_high = conversion::to(bias_half2[1]); - const float2 attn_bias_low = __half22float2(attn_bias_half2[0]); - const float2 attn_bias_high = __half22float2(attn_bias_half2[1]); + const float2 attn_bias_low = conversion::to(attn_bias_half2[0]); + const float2 attn_bias_high = conversion::to(attn_bias_half2[1]); if (preln) { // residual = (residual + attention + bias + attention_bias) * @@ -226,8 +248,8 @@ __global__ void fused_bias_residual(__half* residual, res_high.x = (res_high.x + hs_high.x + bias_high.x); res_high.y = (res_high.y + hs_high.y + bias_high.y); } - res_half2[0] = __float22half2_rn(res_low); - res_half2[1] = __float22half2_rn(res_high); + res_half2[0] = conversion::to(res_low); + res_half2[1] = conversion::to(res_high); res_fl2_ptr[offset] = res_fl2; } @@ -262,9 +284,43 @@ void launch_bias_residual(T* residual, template void launch_bias_residual< float>(float*, float*, float*, float*, float*, int, int, int, bool, cudaStream_t); +#ifdef BF16_AVAILABLE +template void launch_bias_residual<__nv_bfloat16>(__nv_bfloat16*, + __nv_bfloat16*, + __nv_bfloat16*, + __nv_bfloat16*, + __nv_bfloat16*, + int, + int, + int, + bool, + cudaStream_t); +#endif template void launch_bias_residual< __half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t); +#ifdef BF16_AVAILABLE +template __global__ void fused_bias_residual(__nv_bfloat16* residual, + const __nv_bfloat16* hidden_state, + const __nv_bfloat16* attn, + const __nv_bfloat16* bias, + const __nv_bfloat16* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale, + const bool preln); +#endif + +template __global__ void fused_bias_residual(__half* residual, + const __half* hidden_state, + const __half* attn, + const __half* bias, + const __half* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale, + const bool preln); + __global__ void gptj_residual_add(float* residual, const float* hidden_state, const float* attn, @@ -305,15 +361,18 @@ __global__ void gptj_residual_add(float* residual, } } -__global__ void gptj_residual_add(__half* residual, - const __half* hidden_state, - const __half* attn, - const __half* bias, - const __half* attn_bias, +template +__global__ void gptj_residual_add(T* residual, + const T* hidden_state, + const T* attn, + const T* bias, + const T* attn_bias, const int total_count, const int intermediate_size, const float mp_scale) { + using T2 = + typename 
std::conditional::value, __half2, __nv_bfloat162>::type; float2* res_fl2_ptr = reinterpret_cast(residual); const float2* hs_fl2_ptr = reinterpret_cast(hidden_state); const float2* attn_fl2_ptr = reinterpret_cast(attn); @@ -327,28 +386,28 @@ __global__ void gptj_residual_add(__half* residual, const float2 attn_fl2 = attn_fl2_ptr[offset]; const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; - __half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2); - const __half2* hs_half2 = reinterpret_cast(&hs_fl2); - const __half2* attn_half2 = reinterpret_cast(&attn_fl2); - const __half2* bias_half2 = reinterpret_cast(&bias_fl2); + T2* res_half2 = reinterpret_cast(&res_fl2); + const T2* hs_half2 = reinterpret_cast(&hs_fl2); + const T2* attn_half2 = reinterpret_cast(&attn_fl2); + const T2* bias_half2 = reinterpret_cast(&bias_fl2); - float2 res_low = __half22float2(res_half2[0]); - float2 res_high = __half22float2(res_half2[1]); + float2 res_low = conversion::to(res_half2[0]); + float2 res_high = conversion::to(res_half2[1]); - const float2 hs_low = __half22float2(hs_half2[0]); - const float2 hs_high = __half22float2(hs_half2[1]); + const float2 hs_low = conversion::to(hs_half2[0]); + const float2 hs_high = conversion::to(hs_half2[1]); - const float2 attn_low = __half22float2(attn_half2[0]); - const float2 attn_high = __half22float2(attn_half2[1]); + const float2 attn_low = conversion::to(attn_half2[0]); + const float2 attn_high = conversion::to(attn_half2[1]); - const float2 bias_low = __half22float2(bias_half2[0]); - const float2 bias_high = __half22float2(bias_half2[1]); + const float2 bias_low = conversion::to(bias_half2[0]); + const float2 bias_high = conversion::to(bias_half2[1]); if (attn_bias) { const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; - const __half2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); - const float2 attn_bias_low = __half22float2(attn_bias_half2[0]); - const float2 attn_bias_high = __half22float2(attn_bias_half2[1]); + const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); + const float2 attn_bias_low = conversion::to(attn_bias_half2[0]); + const float2 attn_bias_high = conversion::to(attn_bias_half2[1]); // residual += attention_bias res_low.x += attn_bias_low.x; res_low.y += attn_bias_low.y; @@ -361,8 +420,8 @@ __global__ void gptj_residual_add(__half* residual, res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale; res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale; - res_half2[0] = __float22half2_rn(res_low); - res_half2[1] = __float22half2_rn(res_high); + res_half2[0] = conversion::to(res_low); + res_half2[1] = conversion::to(res_high); res_fl2_ptr[offset] = res_fl2; } @@ -396,6 +455,19 @@ template void launch_gptj_residual_add(float*, int, int, cudaStream_t); + +#ifdef BF16_AVAILABLE +template void launch_gptj_residual_add<__nv_bfloat16>(__nv_bfloat16*, + __nv_bfloat16*, + __nv_bfloat16*, + __nv_bfloat16*, + __nv_bfloat16*, + int, + int, + int, + cudaStream_t); +#endif + template void launch_gptj_residual_add<__half>(__half*, __half*, __half*, @@ -405,6 +477,27 @@ template void launch_gptj_residual_add<__half>(__half*, int, int, cudaStream_t); + +#ifdef BF16_AVAILABLE +template __global__ void gptj_residual_add(__nv_bfloat16* residual, + const __nv_bfloat16* hidden_state, + const __nv_bfloat16* attn, + const __nv_bfloat16* bias, + const __nv_bfloat16* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale); +#endif + +template 
__global__ void gptj_residual_add(__half* residual, + const __half* hidden_state, + const __half* attn, + const __half* bias, + const __half* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale); + template __global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim) { @@ -455,6 +548,16 @@ template void launch_moe_res_matmul(float* residual, int seq_len, int hidden_dim, cudaStream_t stream); + +#ifdef BF16_AVAILABLE +template void launch_moe_res_matmul(__nv_bfloat16* residual, + __nv_bfloat16* coef, + __nv_bfloat16* mlp_out, + int seq_len, + int hidden_dim, + cudaStream_t stream); +#endif + template void launch_moe_res_matmul(__half* residual, __half* coef, __half* mlp_out, @@ -462,11 +565,11 @@ template void launch_moe_res_matmul(__half* residual, int hidden_dim, cudaStream_t stream); -__global__ void pad_data_kernel(__half* padded_output, - __half* output, - int head_size, - int padded_head_size) +template +__global__ void pad_data_kernel(T* padded_output, T* output, int head_size, int padded_head_size) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float4* padded_output_cast = reinterpret_cast(padded_output); float4* output_cast = reinterpret_cast(output); int bid = blockIdx.x * (blockDim.y) + threadIdx.y; @@ -474,8 +577,8 @@ __global__ void pad_data_kernel(__half* padded_output, padded_output_cast += (bid * padded_head_size); output_cast += (bid * head_size); float4 ZERO; - const __half2 zero_h = __float2half2_rn(0.f); - __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; if (idx < head_size) @@ -483,12 +586,14 @@ __global__ void pad_data_kernel(__half* padded_output, else padded_output_cast[idx] = ZERO; } + __global__ void pad_data_kernel(float* padded_output, float* output, int head_size, int padded_head_size) { } + template void pad_data(T* padded_output, T* output, @@ -508,6 +613,16 @@ template void pad_data(__half* padded_output, int head_size, int padded_head_size, cudaStream_t stream); + +#ifdef BF16_AVAILABLE +template void pad_data(__nv_bfloat16* padded_output, + __nv_bfloat16* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream); +#endif + template void pad_data(float* padded_output, float* output, int bsz, @@ -515,13 +630,28 @@ template void pad_data(float* padded_output, int padded_head_size, cudaStream_t stream); -__global__ void pad_head_seq_kernel(__half* padded_output, - __half* output, +#ifdef BF16_AVAILABLE +template __global__ void pad_data_kernel(__nv_bfloat16* padded_output, + __nv_bfloat16* output, + int head_size, + int padded_head_size); +#endif + +template __global__ void pad_data_kernel(__half* padded_output, + __half* output, + int head_size, + int padded_head_size); + +template +__global__ void pad_head_seq_kernel(T* padded_output, + T* output, int seq_len, int padded_seq_len, int head_size, int padded_head_size) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float4* padded_output_cast = reinterpret_cast(padded_output); float4* output_cast = reinterpret_cast(output); int bsz = blockIdx.x; @@ -530,8 +660,8 @@ __global__ void pad_head_seq_kernel(__half* padded_output, padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size; output_cast += (bsz * seq_len + bid) * head_size; float4 ZERO; - const __half2 zero_h = __float2half2_rn(0.f); - 
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; @@ -540,6 +670,7 @@ __global__ void pad_head_seq_kernel(__half* padded_output, else padded_output_cast[idx] = ZERO; } + __global__ void pad_head_seq_kernel(float* padded_output, float* output, int seq_len, @@ -548,6 +679,7 @@ __global__ void pad_head_seq_kernel(float* padded_output, int padded_head_size) { } + template void pad_head_seq(T* padded_output, T* output, @@ -563,6 +695,7 @@ void pad_head_seq(T* padded_output, pad_head_seq_kernel<<>>( padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8); } + template void pad_head_seq(__half* padded_output, __half* output, int bsz, @@ -571,6 +704,18 @@ template void pad_head_seq(__half* padded_output, int head_size, int padded_head_size, cudaStream_t stream); + +#ifdef BF16_AVAILABLE +template void pad_head_seq(__nv_bfloat16* padded_output, + __nv_bfloat16* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); +#endif + template void pad_head_seq(float* padded_output, float* output, int bsz, @@ -681,4 +826,12 @@ template void launch_fused_bias_geglu(__half*, int, int, cudaStream_t); +#ifdef BF16_AVAILABLE +template void launch_fused_bias_geglu(__nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + int, + int, + cudaStream_t); +#endif template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, cudaStream_t); diff --git a/csrc/transformer/inference/csrc/layer_norm.cu b/csrc/transformer/inference/csrc/layer_norm.cu index ce3e471b38f9..807d60e740df 100644 --- a/csrc/transformer/inference/csrc/layer_norm.cu +++ b/csrc/transformer/inference/csrc/layer_norm.cu @@ -197,6 +197,16 @@ template void launch_fused_ln(__half*, int, int, cudaStream_t); +#ifdef BF16_AVAILABLE +template void launch_fused_ln(__nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + float, + int, + int, + cudaStream_t); +#endif template void launch_fused_ln(float*, const float*, const float*, const float*, float, int, int, cudaStream_t); @@ -493,6 +503,19 @@ template void launch_fused_residual_ln(__half*, int, cudaStream_t); +#ifdef BF16_AVAILABLE +template void launch_fused_residual_ln(__nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + float, + int, + int, + cudaStream_t); +#endif + template void launch_fused_residual_ln(float*, const float*, const float*, @@ -517,6 +540,20 @@ template void launch_fused_residual_ln_store_pre_ln_res(__half*, int, cudaStream_t); +#ifdef BF16_AVAILABLE +template void launch_fused_residual_ln_store_pre_ln_res(__nv_bfloat16*, + __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + float, + int, + int, + cudaStream_t); +#endif + template void launch_fused_residual_ln_store_pre_ln_res(float*, float*, const float*, diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 88d4201ba179..eb9bcf0e4341 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1744,83 +1744,71 @@ bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspac PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - 
m.def("softmax_fp32", &ds_softmax, "DeepSpeed SoftMax with fp32 (CUDA)"); - m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)"); - m.def( - "softmax_context_fp32", &ds_softmax_context, "DeepSpeed attention with fp32 (CUDA)"); - m.def("softmax_context_fp16", - &ds_softmax_context<__half>, - "DeepSpeed attention with fp16 (CUDA)"); m.def("softmax_context_int8", &ds_softmax_context1<__half>, "DeepSpeed attention with int8 (CUDA)"); - m.def("bias_gelu_fp32", &ds_bias_gelu, "DeepSpeed Gelu with fp32 (CUDA)"); - m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)"); m.def("bias_geglu", &ds_bias_geglu, "DeepSpeed Bias GEGLU (CUDA)"); - m.def("bias_add_fp32", &ds_bias_add, "DeepSpeed Bias Add with fp32 (CUDA)"); - m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)"); - m.def("bias_relu_fp32", &ds_bias_relu, "DeepSpeed ReLU with fp32 (CUDA)"); - m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)"); - m.def("bias_residual_fp32", - &ds_bias_residual, - "DeepSpeed residual-bias add with fp32 (CUDA)"); - m.def("bias_residual_fp16", - &ds_bias_residual<__half>, - "DeepSpeed residual-bias add with fp16 (CUDA)"); m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)"); m.def( "_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)"); m.def("layer_norm_residual_store_pre_ln_res", &ds_layer_norm_residual_store_pre_ln_res, "DeepSpeed layer norm + store pre Layernorm residual (CUDA)"); - m.def("qkv_gemm_fp32", &ds_qkv_gemm, "DeepSpeed qkv gemm with fp32 (CUDA)"); - m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)"); m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)"); - m.def("mlp_gemm_fp32", &ds_mlp_gemm, "DeepSpeed mlp with fp32 (CUDA)"); - m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)"); m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)"); - m.def("vector_matmul_fp32", &ds_vector_matmul, "DeepSpeed vector-MM with fp32 (CUDA)"); - m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); m.def("vector_matmul_int8", &ds_vector_matmul_int8<__half>, "DeepSpeed vector-MM with int8 (CUDA)"); - m.def("linear_layer_fp32", &ds_linear_layer, "DeepSpeed linear_layer with fp32 (CUDA)"); - m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)"); m.def("linear_layer_int8", &ds_linear_layer_int8<__half>, "DeepSpeed linear_layer with int8 (CUDA)"); - m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu, "DeepSpeed mlp with fp32 (CUDA)"); - m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("residual_add_bias_fp32", - &residual_add_bias, - "DeepSpeed residual add with fp32 (CUDA)"); - m.def("residual_add_bias_fp16", - &residual_add_bias<__half>, - "DeepSpeed residual add with fp16 (CUDA)"); m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("einsum_sec_sm_ecm_fp32", - &einsum_sec_sm_ecm, - "DeepSpeed vector-MM with fp32 (CUDA)"); - - m.def("einsum_sec_sm_ecm_fp16", - &einsum_sec_sm_ecm<__half>, - "DeepSpeed vector-MM with fp16 (CUDA)"); m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); - m.def("add_padding_fp32", &add_padding, "DeepSpeed residual add with fp32 (CUDA)"); - m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add 
with fp16 (CUDA)"); - m.def("pad_transform_fp32", - &padd_add_transform, - "DeepSpeed residual add with fp32 (CUDA)"); - m.def("pad_transform_fp16", - &padd_add_transform<__half>, - "DeepSpeed residual add with fp16 (CUDA)"); - m.def("allocate_workspace_fp32", - &allocate_workspace, - "DeepSpeed memory allocation for GPT inference with fp32 (CUDA)"); - m.def("allocate_workspace_fp16", - &allocate_workspace<__half>, - "DeepSpeed memory allocation for GPT inference with fp16 (CUDA)"); m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks"); m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace"); m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace"); + +#define DEF_OPS(_name, _dtype) \ + m.def("softmax_" #_name, &ds_softmax<_dtype>, "DeepSpeed SoftMax with " #_name " (CUDA)"); \ + m.def("softmax_context_" #_name, \ + &ds_softmax_context<_dtype>, \ + "DeepSpeed attention with _name (CUDA)"); \ + m.def("bias_gelu_" #_name, &ds_bias_gelu<_dtype>, "DeepSpeed Gelu with " #_name " (CUDA)"); \ + m.def("bias_add_" #_name, &ds_bias_add<_dtype>, "DeepSpeed Bias Add with " #_name " (CUDA)"); \ + m.def("bias_relu_" #_name, &ds_bias_relu<_dtype>, "DeepSpeed ReLU with " #_name " (CUDA)"); \ + m.def("bias_residual_" #_name, \ + &ds_bias_residual<_dtype>, \ + "DeepSpeed residual-bias add with " #_name " (CUDA)"); \ + m.def("qkv_gemm_" #_name, &ds_qkv_gemm<_dtype>, "DeepSpeed qkv gemm with " #_name " (CUDA)"); \ + m.def("mlp_gemm_" #_name, &ds_mlp_gemm<_dtype>, "DeepSpeed mlp with " #_name " (CUDA)"); \ + m.def("vector_matmul_" #_name, \ + &ds_vector_matmul<_dtype>, \ + "DeepSpeed vector-MM with " #_name " (CUDA)"); \ + m.def("linear_layer_" #_name, \ + &ds_linear_layer<_dtype>, \ + "DeepSpeed linear_layer with " #_name " (CUDA)"); \ + m.def("fused_gemm_gelu_" #_name, \ + &fused_gemm_gelu<_dtype>, \ + "DeepSpeed mlp with " #_name " (CUDA)"); \ + m.def("residual_add_bias_" #_name, \ + &residual_add_bias<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("einsum_sec_sm_ecm_" #_name, \ + &einsum_sec_sm_ecm<_dtype>, \ + "DeepSpeed vector-MM with " #_name " (CUDA)"); \ + m.def("add_padding_" #_name, \ + &add_padding<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("pad_transform_" #_name, \ + &padd_add_transform<_dtype>, \ + "DeepSpeed residual add with " #_name " (CUDA)"); \ + m.def("allocate_workspace_" #_name, \ + &allocate_workspace<_dtype>, \ + "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)") + + DEF_OPS(fp32, float); + DEF_OPS(fp16, __half); +#ifdef BF16_AVAILABLE + DEF_OPS(bf16, __nv_bfloat16); +#endif } diff --git a/csrc/transformer/inference/csrc/relu.cu b/csrc/transformer/inference/csrc/relu.cu index bf6eac269469..fd38d1a88d4b 100644 --- a/csrc/transformer/inference/csrc/relu.cu +++ b/csrc/transformer/inference/csrc/relu.cu @@ -61,4 +61,11 @@ void launch_bias_relu(T* input, } template void launch_bias_relu(float*, const float*, int, int, cudaStream_t); +#ifdef BF16_AVAILABLE +template void launch_bias_relu<__nv_bfloat16>(__nv_bfloat16*, + const __nv_bfloat16*, + int, + int, + cudaStream_t); +#endif template void launch_bias_relu<__half>(__half*, const __half*, int, int, cudaStream_t); diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index 80eff139c3e9..7befdfd51497 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -4,6 +4,7 @@ // DeepSpeed Team #include +#include 
"conversion_utils.h" #include "inference_cuda_layers.h" #ifndef __HIP_PLATFORM_HCC__ @@ -30,10 +31,10 @@ void CheckCudaErrorAux(const char* file, unsigned line) namespace cg = cooperative_groups; -template -__global__ void attn_softmax_v2(__half* vals, - __half* mask, - __half* alibi, +template +__global__ void attn_softmax_v2(T* vals, + T* mask, + T* alibi, float layer_scale, bool triangular, bool recompute, @@ -53,7 +54,7 @@ __global__ void attn_softmax_v2(__half* vals, float2 low_data[MAX_REG_SIZE]; float2 high_data[MAX_REG_SIZE]; - const __half zero_h = __float2half(0.f); + const T zero_h = conversion::to(0.f); int wid = threadIdx.x >> 5; int lane = threadIdx.x & 0x1f; @@ -101,73 +102,87 @@ __global__ void attn_softmax_v2(__half* vals, ((data_id + reduceWidth * 3) > window_stride); if (mask && alibi) { - low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset])) + - (__half2float(mask[data_id + mask_offset])) - : minus_infinity; + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset])) + + (conversion::to(mask[data_id + mask_offset])) + : minus_infinity; low_data[i].y = - low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth])) + - (__half2float(mask[data_id + mask_offset + reduceWidth])) - : minus_infinity; + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset + reduceWidth])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) + - (__half2float(mask[data_id + mask_offset + reduceWidth * 2])) + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 2])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 2])) : minus_infinity; high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) + - (__half2float(mask[data_id + mask_offset + reduceWidth * 3])) + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 3])) + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 3])) : minus_infinity; } else if (mask) { - low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + - (__half2float(mask[data_id + mask_offset])) - : minus_infinity; - low_data[i].y = low_y_check - ? __half2float(vals[data_id + reduceWidth]) * layer_scale + - (__half2float(mask[data_id + mask_offset + reduceWidth])) + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(mask[data_id + mask_offset])) : minus_infinity; + low_data[i].y = + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; high_data[i].x = - high_x_check ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + - (__half2float(mask[data_id + mask_offset + reduceWidth * 2])) - : minus_infinity; + high_x_check + ? 
conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 2])) + : minus_infinity; high_data[i].y = - high_y_check ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + - (__half2float(mask[data_id + mask_offset + reduceWidth * 3])) - : minus_infinity; + high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to(mask[data_id + mask_offset + reduceWidth * 3])) + : minus_infinity; } else if (alibi) { - low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset])) - : minus_infinity; + low_data[i].x = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset])) + : minus_infinity; low_data[i].y = - low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth])) - : minus_infinity; + low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset + reduceWidth])) + : minus_infinity; high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 2])) : minus_infinity; high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + - (__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) + ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 3])) : minus_infinity; } else { - low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + low_data[i].x = low_x_check ? conversion::to(vals[data_id]) * layer_scale : minus_infinity; - low_data[i].y = low_y_check - ? __half2float(vals[data_id + reduceWidth]) * layer_scale - : minus_infinity; - high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale - : minus_infinity; - high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale - : minus_infinity; + low_data[i].y = + low_y_check ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + : minus_infinity; + high_data[i].x = + high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + : minus_infinity; + high_data[i].y = + high_y_check + ? 
conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + : minus_infinity; } // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); @@ -225,13 +240,13 @@ __global__ void attn_softmax_v2(__half* vals, for (int i = 0; i < iterations; i++) { int data_id = i * (reduceWidth << 2) + (seq_lane); if (data_id < sequence_length) { - vals[data_id] = __float2half(low_data[i].x / sum); + vals[data_id] = conversion::to(low_data[i].x / sum); if ((data_id + reduceWidth) < sequence_length) - vals[data_id + reduceWidth] = __float2half(low_data[i].y / sum); + vals[data_id + reduceWidth] = conversion::to(low_data[i].y / sum); if ((data_id + reduceWidth * 2) < sequence_length) - vals[data_id + reduceWidth * 2] = __float2half(high_data[i].x / sum); + vals[data_id + reduceWidth * 2] = conversion::to(high_data[i].x / sum); if ((data_id + reduceWidth * 3) < sequence_length) - vals[data_id + reduceWidth * 3] = __float2half(high_data[i].y / sum); + vals[data_id + reduceWidth * 3] = conversion::to(high_data[i].y / sum); } } } @@ -389,23 +404,23 @@ __global__ void attn_softmax_v2(float* vals, } } -#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ - attn_softmax_v2<<>>(vals, \ - mask, \ - alibi, \ - layer_scale, \ - triangular, \ - recompute, \ - local_attention, \ - window_size, \ - total_count, \ - heads, \ - sequence_length, \ - num_seq, \ - head_offset, \ - mask_stride, \ - mp_size, \ - reduce_width); +#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ + attn_softmax_v2<<>>(vals, \ + mask, \ + alibi, \ + layer_scale, \ + triangular, \ + recompute, \ + local_attention, \ + window_size, \ + total_count, \ + heads, \ + sequence_length, \ + num_seq, \ + head_offset, \ + mask_stride, \ + mp_size, \ + reduce_width); template void launch_attn_softmax_v2(T* vals, @@ -488,6 +503,26 @@ template void launch_attn_softmax_v2(float* vals, int mask_stride, int mp_size, cudaStream_t stream); + +#ifdef BF16_AVAILABLE +template void launch_attn_softmax_v2(__nv_bfloat16* vals, + __nv_bfloat16* mask, + __nv_bfloat16* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + int head_offset, + int mask_stride, + int mp_size, + cudaStream_t stream); +#endif + template void launch_attn_softmax_v2(__half* vals, __half* mask, __half* alibi, @@ -504,3 +539,53 @@ template void launch_attn_softmax_v2(__half* vals, int mask_stride, int mp_size, cudaStream_t stream); + +#define DEF_ATTN_SOFTMAX_V2_HALF(_iter) \ + template __global__ void attn_softmax_v2<__half, _iter>(__half * vals, \ + __half * mask, \ + __half * alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int total_count, \ + int heads, \ + int sequence_length, \ + int num_seq, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + int reduceWidth) + +#define DEF_ATTN_SOFTMAX_V2_BF16(_iter) \ + template __global__ void attn_softmax_v2<__nv_bfloat16, _iter>(__nv_bfloat16 * vals, \ + __nv_bfloat16 * mask, \ + __nv_bfloat16 * alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int total_count, \ + int heads, \ + int sequence_length, \ + int num_seq, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + int reduceWidth) + +#define FOREACH_ITERATIONS(cb) \ + cb(1); \ + cb(2); \ + cb(4); \ + cb(8); \ + cb(16); \ + cb(32); \ + cb(64) + +FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_HALF); +#ifdef 
BF16_AVAILABLE +FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_BF16); +#endif diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 8ef0d3289bb5..00d919039b5d 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -6,9 +6,15 @@ #ifndef __HIP_PLATFORM_HCC__ #include #endif +#include "conversion_utils.h" #include "inference_cuda_layers.h" namespace cg = cooperative_groups; +// only used to avoid compilation error due to lack of definition. +#ifndef BF16_AVAILABLE +using __nv_bfloat162 = __half2; +#endif + // Bias add __global__ void bias_add_transform_0213(float* output, @@ -75,11 +81,12 @@ __global__ void bias_add_transform_0213(float* output, #define ATTN_H 3 #define MAX_SEQ_LINE 10 -__global__ void bias_add_transform_0213(__half* output, // q - __half* k_cache, - __half* v_cache, - const __half* vals, // qkv - const __half* bias, +template +__global__ void bias_add_transform_0213(T* output, // q + T* k_cache, + T* v_cache, + const T* vals, // qkv + const T* bias, int hidden_dim, int seq_length, unsigned seq_offset, @@ -91,6 +98,8 @@ __global__ void bias_add_transform_0213(__half* output, // q int head_ext, int max_out_tokens) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; unsigned half_dim = (rotary_dim << 3) >> 1; int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; @@ -108,8 +117,8 @@ __global__ void bias_add_transform_0213(__half* output, // q float4 vals_arr; float4 output_arr; - __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); - __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + T2* vals_half = reinterpret_cast(&vals_arr); + T2* output_half = reinterpret_cast(&output_arr); const float4* vals_vec = reinterpret_cast(vals); float4* output_vec = @@ -129,17 +138,19 @@ __global__ void bias_add_transform_0213(__half* output, // q int lane = d3 & 0x1f; if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) { float4 q = vals_vec[d3]; - __half2* q_h = reinterpret_cast<__half2*>(&q); + T2* q_h = reinterpret_cast(&q); if (rotate_every_two) { #pragma unroll for (int o = 0; o < 4; o++) { float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3); inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; float q_data[2]; - q_data[0] = (float)q_h[o].x; - q_data[1] = (float)q_h[o].y; - q_h[o].x = (__half)(-1.0 * q_data[1] * sinf(inv_freq) + q_data[0] * cosf(inv_freq)); - q_h[o].y = (__half)(q_data[0] * sinf(inv_freq) + q_data[1] * cosf(inv_freq)); + q_data[0] = conversion::to(q_h[o].x); + q_data[1] = conversion::to(q_h[o].y); + q_h[o].x = conversion::to(-1.0 * q_data[1] * sinf(inv_freq) + + q_data[0] * cosf(inv_freq)); + q_h[o].y = + conversion::to(q_data[0] * sinf(inv_freq) + q_data[1] * cosf(inv_freq)); } } output_vec[d3] = q; @@ -188,16 +199,17 @@ void launch_bias_add_transform_0213(float* output, head_ext, max_out_tokens); } + template -void launch_bias_add_transform_0213(T* outputs, - T* vals, - T* vals1, - const T* vals2, +void launch_bias_add_transform_0213(T* output, + T* k_cache, + T* v_cache, + const T* vals, const T* bias, int batch_size, int seq_length, unsigned seq_offset, - int seq_length1, + int all_tokens, int hidden_dim, int heads, int rotary_dim, @@ -205,25 +217,7 @@ void launch_bias_add_transform_0213(T* outputs, bool rotate_every_two, cudaStream_t stream, int trans_count, - int max_out_tokens); -template <> -void launch_bias_add_transform_0213<__half>(__half* output, - __half* k_cache, - 
__half* v_cache, - const __half* vals, - const __half* bias, - int batch_size, - int seq_length, - unsigned seq_offset, - int all_tokens, - int hidden_dim, - int heads, - int rotary_dim, - bool rotate_half, - bool rotate_every_two, - cudaStream_t stream, - int trans_count, - int max_out_tokens) + int max_out_tokens) { hidden_dim >>= 3; int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; @@ -246,6 +240,44 @@ void launch_bias_add_transform_0213<__half>(__half* output, max_out_tokens); } +#ifdef BF16_AVAILABLE +template void launch_bias_add_transform_0213(__nv_bfloat16* output, + __nv_bfloat16* k_cache, + __nv_bfloat16* v_cache, + const __nv_bfloat16* vals, + const __nv_bfloat16* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int all_tokens, + int hidden_dim, + int heads, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count, + int max_out_tokens); +#endif + +template void launch_bias_add_transform_0213(__half* output, + __half* k_cache, + __half* v_cache, + const __half* vals, + const __half* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int all_tokens, + int hidden_dim, + int heads, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count, + int max_out_tokens); + // Bias add __global__ void pad_add_transform_0213(float* output, @@ -258,17 +290,20 @@ __global__ void pad_add_transform_0213(float* output, { } -__global__ void pad_add_transform_0213(__half* output, - const __half* vals, +template +__global__ void pad_add_transform_0213(T* output, + const T* vals, int hidden_dim, int seq_length, int padded_seq_len, int heads, int padded_head_size) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; float4 ZERO; - const __half2 zero_h = __float2half2_rn(0.f); - __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); + const T2 zero_h = conversion::to(0.f); + T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; @@ -301,17 +336,6 @@ __global__ void pad_add_transform_0213(__half* output, output_vec[d3] = ZERO; } -template -void launch_pad_add_transform_0213(T* output, - const T* vals, - int batch_size, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size, - cudaStream_t stream); - // [B S C*H] - > C * [B A S N] template <> void launch_pad_add_transform_0213(float* output, @@ -325,16 +349,17 @@ void launch_pad_add_transform_0213(float* output, cudaStream_t stream) { } -template <> -void launch_pad_add_transform_0213<__half>(__half* output, - const __half* vals, - int batch_size, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size, - cudaStream_t stream) + +template +void launch_pad_add_transform_0213(T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) { hidden_dim >>= 3; dim3 block_dim((padded_head_size >> 3), heads, 2); @@ -343,6 +368,28 @@ void launch_pad_add_transform_0213<__half>(__half* output, output, vals, hidden_dim, seq_length, padded_seq_len, heads, padded_head_size >> 3); } +#ifdef BF16_AVAILABLE +template void launch_pad_add_transform_0213(__nv_bfloat16* output, + const __nv_bfloat16* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream); +#endif + +template void 
launch_pad_add_transform_0213(__half* output, + const __half* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream); + // Bias add template __global__ void bias_add_transform_0213(T* output, @@ -394,15 +441,17 @@ __global__ void bias_add_transform_0213(float* output, d2 * d2_out_stride + d3] = outputs; } -template <> -__global__ void bias_add_transform_0213<__half>(__half* output, - const __half* vals, - const __half* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext) +template +__global__ void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; @@ -418,9 +467,9 @@ __global__ void bias_add_transform_0213<__half>(__half* output, float4 vals_arr; float4 bias_arr; float4 output_arr; - __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr); - __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + T2* vals_half = reinterpret_cast(&vals_arr); + T2* bias_half = reinterpret_cast(&bias_arr); + T2* output_half = reinterpret_cast(&output_arr); const float4* vals_vec = reinterpret_cast(vals); const float4* bias_vec = reinterpret_cast(bias); @@ -449,13 +498,16 @@ __global__ void bias_add_transform_0213<__half>(__half* output, output_vec[d3] = output_arr; } -__global__ void bias_add_transform_0213_v2(__half* output, - const __half* vals, - const __half* bias, +template +__global__ void bias_add_transform_0213_v2(T* output, + const T* vals, + const T* bias, int hidden_dim, int seq_length, int heads) { + using T2 = + typename std::conditional::value, __half2, __nv_bfloat162>::type; __shared__ float4 in_data[3072]; int d0_stride = hidden_dim * seq_length; @@ -477,9 +529,9 @@ __global__ void bias_add_transform_0213_v2(__half* output, float4 vals_arr[1]; float4 bias_arr[1]; float4 output_arr[1]; - __half2* vals_half = reinterpret_cast<__half2*>(vals_arr); - __half2* bias_half = reinterpret_cast<__half2*>(bias_arr); - __half2* output_half = reinterpret_cast<__half2*>(output_arr); + T2* vals_half = reinterpret_cast(vals_arr); + T2* bias_half = reinterpret_cast(bias_arr); + T2* output_half = reinterpret_cast(output_arr); const float4* vals_vec = reinterpret_cast(vals); const float4* bias_vec = reinterpret_cast(bias); @@ -519,6 +571,22 @@ __global__ void bias_add_transform_0213_v2(__half* output, } } +template __global__ void bias_add_transform_0213_v2(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads); + +#ifdef BF16_AVAILABLE +template __global__ void bias_add_transform_0213_v2(__nv_bfloat16* output, + const __nv_bfloat16* vals, + const __nv_bfloat16* bias, + int hidden_dim, + int seq_length, + int heads); +#endif + template __global__ void transform4d_0213(T* out, const T* in, @@ -560,13 +628,13 @@ __global__ void transform4d_0213(float* out, } } -template <> -__global__ void transform4d_0213<__half>(__half* out, - const __half* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext) +template +__global__ void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) { int d0_stride = hidden_dim * (seq_length / head_ext); int d1_stride = 
hidden_dim; @@ -594,11 +662,8 @@ __global__ void transform4d_0213<__half>(__half* out, out_vec[d3] = in_vec[d3]; } -__global__ void transform4d_0213_v2(__half* out, - const __half* in, - int heads, - int seq_length, - int hidden_dim) +template +__global__ void transform4d_0213_v2(T* out, const T* in, int heads, int seq_length, int hidden_dim) { __shared__ float4 in_data[3072]; @@ -642,6 +707,20 @@ __global__ void transform4d_0213_v2(__half* out, } } +#ifdef BF16_AVAILABLE +template __global__ void transform4d_0213_v2(__nv_bfloat16* out, + const __nv_bfloat16* in, + int heads, + int seq_length, + int hidden_dim); +#endif + +template __global__ void transform4d_0213_v2(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim); + // 3 * [B A S N] - > [B S C*H] template <> void launch_transform4d_0213(float* out, @@ -660,20 +739,40 @@ void launch_transform4d_0213(float* out, <<>>(out, in, heads, seq_length, hidden_dim, 1); } -template <> -void launch_transform4d_0213<__half>(__half* out, - const __half* in, - int batch_size, - int heads, - int seq_length, - int hidden_dim, - cudaStream_t stream, - int trans_count) +template +void launch_transform4d_0213(T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) { hidden_dim >>= 3; int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); dim3 block_dims(hidden_dim / heads, (heads / head_ext)); - transform4d_0213<__half> - <<>>(out, in, heads, seq_length, hidden_dim, head_ext); + transform4d_0213<<>>( + out, in, heads, seq_length, hidden_dim, head_ext); } + +#ifdef BF16_AVAILABLE +template void launch_transform4d_0213(__nv_bfloat16* out, + const __nv_bfloat16* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count); +#endif + +template void launch_transform4d_0213(__half* out, + const __half* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count); diff --git a/csrc/transformer/inference/includes/inference_cublas_wrappers.h b/csrc/transformer/inference/includes/inference_cublas_wrappers.h index e899ec266d83..38a0f7ab66d2 100644 --- a/csrc/transformer/inference/includes/inference_cublas_wrappers.h +++ b/csrc/transformer/inference/includes/inference_cublas_wrappers.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #ifndef __HIP_PLATFORM_HCC__ @@ -106,6 +107,7 @@ int cublas_gemm_ex(cublasHandle_t handle, return 0; } +template #ifdef __HIP_PLATFORM_HCC__ int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, @@ -115,9 +117,9 @@ int cublas_gemm_ex(rocblas_handle handle, int k, const float* alpha, const float* beta, - const __half* A, - const __half* B, - __half* C, + const T* A, + const T* B, + T* C, rocblas_gemm_algo algo) #else int cublas_gemm_ex(cublasHandle_t handle, @@ -128,13 +130,15 @@ int cublas_gemm_ex(cublasHandle_t handle, int k, const float* alpha, const float* beta, - const __half* A, - const __half* B, - __half* C, + const T* A, + const T* B, + T* C, cublasGemmAlgo_t algo) #endif { #ifdef __HIP_PLATFORM_HCC__ + constexpr auto rocblas_dtype_16 = std::is_same::value ? 
rocblas_datatype_f16_r + : rocblas_datatype_bf16_r; rocblas_status status = rocblas_gemm_ex(handle, transa, transb, @@ -143,23 +147,24 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, - rocblas_datatype_f16_r, + rocblas_dtype_16, (transa == rocblas_operation_none) ? m : k, (const void*)B, - rocblas_datatype_f16_r, + rocblas_dtype_16, (transb == rocblas_operation_none) ? k : n, (const void*)beta, (void*)C, - rocblas_datatype_f16_r, + rocblas_dtype_16, m, (void*)C, - rocblas_datatype_f16_r, + rocblas_dtype_16, m, rocblas_datatype_f32_r, algo, 0, 0); #else + constexpr auto cublas_dtype_16 = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; cublasStatus_t status = cublasGemmEx(handle, transa, transb, @@ -168,14 +173,14 @@ int cublas_gemm_ex(cublasHandle_t handle, k, (const void*)alpha, (const void*)A, - CUDA_R_16F, + cublas_dtype_16, (transa == CUBLAS_OP_N) ? m : k, (const void*)B, - CUDA_R_16F, + cublas_dtype_16, (transb == CUBLAS_OP_N) ? k : n, (const void*)beta, (void*)C, - CUDA_R_16F, + cublas_dtype_16, m, CUDA_R_32F, algo); @@ -307,6 +312,7 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, return 0; } +template #ifdef __HIP_PLATFORM_HCC__ int cublas_strided_batched_gemm(rocblas_handle handle, int m, @@ -314,9 +320,9 @@ int cublas_strided_batched_gemm(rocblas_handle handle, int k, const float* alpha, const float* beta, - const __half* A, - const __half* B, - __half* C, + const T* A, + const T* B, + T* C, rocblas_operation op_A, rocblas_operation op_B, int stride_A, @@ -331,9 +337,9 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, int k, const float* alpha, const float* beta, - const __half* A, - const __half* B, - __half* C, + const T* A, + const T* B, + T* C, cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, @@ -344,6 +350,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, #endif { #ifdef __HIP_PLATFORM_HCC__ + constexpr auto rocblas_dtype_16 = std::is_same::value ? rocblas_datatype_f16_r + : rocblas_datatype_bf16_r; rocblas_status status = rocblas_gemm_strided_batched_ex(handle, op_A, @@ -353,20 +361,20 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, - rocblas_datatype_f16_r, + rocblas_dtype_16, (op_A == rocblas_operation_none) ? m : k, stride_A, B, - rocblas_datatype_f16_r, + rocblas_dtype_16, (op_B == rocblas_operation_none) ? k : n, stride_B, beta, C, - rocblas_datatype_f16_r, + rocblas_dtype_16, m, stride_C, C, - rocblas_datatype_f16_r, + rocblas_dtype_16, m, stride_C, batch, @@ -375,6 +383,7 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, 0, 0); #else + constexpr auto cublas_dtype_16 = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; cublasStatus_t status = cublasGemmStridedBatchedEx(handle, op_A, op_B, @@ -383,16 +392,16 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, k, alpha, A, - CUDA_R_16F, + cublas_dtype_16, (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, - CUDA_R_16F, + cublas_dtype_16, (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, - CUDA_R_16F, + cublas_dtype_16, m, stride_C, batch, diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 8e5c1ae4f44b..e7e835a2dddf 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -8,6 +8,9 @@ #include "ds_kernel_utils.h" #include +#ifdef BF16_AVAILABLE +#include +#endif #include #include #include
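
For reviewers, a minimal standalone sketch (not part of the patch) of the pattern the templated kernels above follow: arithmetic stays in fp32, conversion::to<>() from conversion_utils.h casts in and out of T, and the bf16 instantiation is fenced behind BF16_AVAILABLE exactly as in the hunks above. The kernel and launcher names (scale_kernel, launch_scale) are illustrative only.

#include <cuda_fp16.h>
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
#include "conversion_utils.h"

// Hypothetical example kernel: promote to fp32, compute, demote back to T.
template <typename T>
__global__ void scale_kernel(T* data, float scale, int n)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = conversion::to<float>(data[idx]);  // T -> fp32
        data[idx] = conversion::to<T>(val * scale);    // fp32 -> T
    }
}

// Hypothetical launcher; the patch repeats this shape for every kernel it templates.
template <typename T>
void launch_scale(T* data, float scale, int n, cudaStream_t stream)
{
    const dim3 block(256);
    const dim3 grid((n + 255) / 256);
    scale_kernel<T><<<grid, block, 0, stream>>>(data, scale, n);
}

// Explicit instantiations; bf16 is guarded so builds without __nv_bfloat16 still compile.
template void launch_scale<float>(float*, float, int, cudaStream_t);
template void launch_scale<__half>(__half*, float, int, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_scale<__nv_bfloat16>(__nv_bfloat16*, float, int, cudaStream_t);
#endif

This single fp32-centered body is what lets the separate float and __half overloads in apply_rotary_pos_emb.cu and dequantize.cu collapse into one templated kernel.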
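
The vectorized kernels (fused_bias_residual, gptj_residual_add, pad_data_kernel, bias_add_transform_0213) additionally pick a packed 16-bit companion type with std::conditional so that one float2/float4 load carries four elements. A sketch of that pattern, again with a hypothetical kernel name (scale_packed_kernel) and assuming the conversion_utils.h overloads shown above; it is only meaningful for 16-bit T, which is why the patch keeps separate fp32 overloads for these kernels.

#include <type_traits>
#include <cuda_fp16.h>
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#else
// Same trick as in gelu.cu / transform.cu above: only to avoid a compilation error
// when bf16 is unavailable; the bf16 path is never instantiated in that case.
using __nv_bfloat162 = __half2;
#endif
#include "conversion_utils.h"

template <typename T>
__global__ void scale_packed_kernel(T* data, float scale, int num_float2)
{
    // Packed companion type: __half2 for __half, __nv_bfloat162 for __nv_bfloat16.
    using T2 =
        typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;

    float2* data_cast = reinterpret_cast<float2*>(data);  // one float2 = four 16-bit elements
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < num_float2) {
        float2 raw = data_cast[idx];
        T2* packed = reinterpret_cast<T2*>(&raw);

        float2 lo = conversion::to<float2>(packed[0]);  // unpack each pair to fp32
        float2 hi = conversion::to<float2>(packed[1]);
        lo.x *= scale; lo.y *= scale;
        hi.x *= scale; hi.y *= scale;

        packed[0] = conversion::to<T2>(lo);             // repack
        packed[1] = conversion::to<T2>(hi);
        data_cast[idx] = raw;
    }
}

The T2 alias here is the same one that replaces the hard-coded __half2 pointers in the hunks above, so the load/store width and math are unchanged for the fp16 path.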
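
On the GEMM side, the cuBLAS/rocBLAS wrappers select the 16-bit datatype at compile time from T while keeping fp32 compute. A minimal sketch of that selection (the helper name gemm_dtype_16 is illustrative; CUDA_R_16BF assumes CUDA 11+, as does the bf16 path in general):

#include <cublas_v2.h>
#include <cuda_fp16.h>
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
#include <type_traits>

// Hypothetical helper mirroring the cublas_dtype_16 constants in
// inference_cublas_wrappers.h: fp16 or bf16 operands, while accumulation stays CUDA_R_32F.
template <typename T>
constexpr cudaDataType_t gemm_dtype_16()
{
    return std::is_same<T, __half>::value ? CUDA_R_16F : CUDA_R_16BF;
}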