Skip to content

Commit

Permalink
[cherry-pick] [ROCM] bugfix for bilinear_interp_v2_grad (PaddlePaddl…
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored Sep 28, 2021
1 parent cea0bc2 commit c576169
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions paddle/fluid/operators/interpolate_v2_op.cu
Original file line number Diff line number Diff line change
@@ -1198,7 +1198,12 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpFw<T><<<config.block_per_grid, 512, 0,
#ifdef __HIPCC__
constexpr int thread_per_block = 256;
#else
constexpr int thread_per_block = 512;
#endif
KeBicubicInterpFw<T><<<config.block_per_grid, thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
@@ -1606,9 +1611,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false;
bool optimize_flag = false;
#ifndef __HIPCC__
optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6))
? true
: ((in_h == 1 && in_w == 1) ? true : false);
#endif

if (optimize_flag & is_nchw) {
KeBilinearInterpBwShareMemory<
@@ -1623,7 +1630,12 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
ratio_h, ratio_w, align_type_value, is_nchw);
}
} else if ("bicubic" == interp_method) {
KeBicubicInterpBw<T><<<config.block_per_grid, 512, 0,
#ifdef __HIPCC__
constexpr int thread_per_block = 256;
#else
constexpr int thread_per_block = 512;
#endif
KeBicubicInterpBw<T><<<config.block_per_grid, thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);

0 comments on commit c576169

Please sign in to comment.