Commit 8a5d526
[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu and cross_entropy.cu code style (hpcaitech#634)
ExtremeViscent authored and binmakeswell committed Apr 6, 2022
1 parent ad1e7ab commit 8a5d526
Showing 2 changed files with 29 additions and 15 deletions.
colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu
2 changes: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ __global__ void ls_cross_entropy_fw_kernel(
   const int left_idx = block_start + threadIdx.x;
   const int right_idx = (blockIdx.x + 1) * vocab_size;
   float max_input[1] = {REDUCE_FLOAT_INF_NEG};
-  float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
+  float sum_logits[2] = {0.f, 0.f};  // logit and logit exp
   int target_tid = targets[blockIdx.x];

   if (target_tid == padding_idx) {
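The accumulators touched in this hunk implement the numerically stable log-softmax trick: track the running maximum of the logits, sum exponentials shifted by that maximum, and recover the loss from the log of the sum. A minimal host-side C++ sketch of the same formula, with a hypothetical helper name (the actual kernel computes this with a block-wide reduction):

#include <algorithm>
#include <cmath>
#include <vector>

// Stable cross-entropy for one row of logits:
// CE = log(sum_j exp(x_j - max)) + max - x_target
float cross_entropy_sketch(const std::vector<float> &logits, int target) {
  float max_input = *std::max_element(logits.begin(), logits.end());
  float sum_exp = 0.f;
  for (float x : logits)
    sum_exp += std::exp(x - max_input);
  return std::log(sum_exp) + max_input - logits[target];
}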
colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
42 changes: 28 additions & 14 deletions

@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
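The `i * 4` guard reformatted throughout this file belongs to a vectorized dropout: each thread draws four uniforms from one Philox state and handles four consecutive elements, so any thread whose first element index `i * 4` falls past `total_count` exits immediately. A simplified sketch of that pattern, with a hypothetical kernel name and the mask output omitted (the real ls_dropout_kernel also stores a per-element mask):

#include <curand_kernel.h>

// Simplified vectorized dropout: one thread processes 4 floats.
__global__ void dropout4_sketch(const int total_count, const float ratio,
                                float *out, const float *in, const int seed) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count)
    return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  float4 rand = curand_uniform4(&state);  // four uniforms in (0, 1]

  const float4 *in4 = reinterpret_cast<const float4 *>(in);
  float4 *out4 = reinterpret_cast<float4 *>(out);
  float4 v = in4[i];

  // Keep each element with probability (1 - ratio); rescale the
  // survivors so the expected value is preserved.
  v.x = rand.x > ratio ? v.x * scale : 0.f;
  v.y = rand.y > ratio ? v.y * scale : 0.f;
  v.z = rand.z > ratio ? v.z * scale : 0.f;
  v.w = rand.w > ratio ? v.w * scale : 0.f;
  out4[i] = v;
}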
@@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,

   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;

   uint8_t m[4];

@@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,

   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;

   float4 *out4 = reinterpret_cast<float4 *>(out);
   const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
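The backward kernels in these hunks do not re-draw random numbers; they reload the mask saved by the forward pass and apply the same scale to the incoming gradient. A scalar (non-vectorized) sketch of that idea with hypothetical names; the real kernels process 4 or 8 elements per thread through float4 loads as above:

#include <cstdint>

// Simplified dropout backward: grad_in = grad_out * mask * scale.
__global__ void dropout_bwd_sketch(const int total_count, const float ratio,
                                   float *grad_in, const float *grad_out,
                                   const uint8_t *mask) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= total_count)
    return;
  grad_in[i] = grad_out[i] * mask[i] * scale;
}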
@@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel(

   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();

-  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < 32; i <<= 1)
+    sum += g.shfl_down(sum, i);

-  if (y == 0) tile[0][x] = sum;
+  if (y == 0)
+    tile[0][x] = sum;
   __syncthreads();

   if (threadIdx.x < 8) {
@@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();

-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_down(sum, i);

-  if (y == 0) tile[0][x] = sum;
+  if (y == 0)
+    tile[0][x] = sum;
   __syncthreads();

   if (threadIdx.x < 8) {
@@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel(

   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
   float sum = tile[threadIdx.y][threadIdx.x];
   __syncthreads();

-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_down(sum, i);

-  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
+  if (threadIdx.x == 0)
+    tile[0][threadIdx.y] = sum;
   __syncthreads();

   if (threadIdx.y == 0) {
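The one-line loops split up in the last few hunks are a standard warp-level tree reduction: `g.shfl_down(sum, i)` fetches `sum` from the lane `i` positions above, and doubling the offset each iteration leaves lane 0 holding the sum of all 32 lanes after five steps; a single thread then publishes the warp's result. A self-contained sketch of the pattern using cooperative groups, with a hypothetical kernel name and assuming blockDim.x is a multiple of 32:

#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Each warp sums the values of its 32 threads; lane 0 writes the result.
__global__ void warp_reduce_sketch(const float *in, float *out) {
  cg::thread_block_tile<32> g =
      cg::tiled_partition<32>(cg::this_thread_block());
  float sum = in[blockIdx.x * blockDim.x + threadIdx.x];

  // Tree reduction over shuffles: offsets 1, 2, 4, 8, 16.
  for (int i = 1; i < 32; i <<= 1)
    sum += g.shfl_down(sum, i);

  if (g.thread_rank() == 0)
    out[blockIdx.x * (blockDim.x / 32) + threadIdx.x / 32] = sum;
}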
