[Optimization] Add custom NCHW to NHWC kernel for implicit GEMM (#2530)
wingertge authored Nov 25, 2024
1 parent fe3e43a commit 0b614b7
Showing 11 changed files with 270 additions and 24 deletions.
24 changes: 12 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
8 changes: 7 additions & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
@@ -21,6 +21,8 @@ use crate::{
FloatElement, IntElement, JitRuntime,
};

use super::nchw_to_nhwc;

/// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available.
///
/// * `input` - The input feature map
@@ -84,7 +86,11 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
);
}

-let input = into_contiguous(permute(input, &[0, 2, 3, 1]));
+// If input is contiguous NCHW, use custom transpose kernel
+let input = match input.is_contiguous() {
+    true => nchw_to_nhwc::<R, F>(input),
+    false => into_contiguous(permute(input, &[0, 2, 3, 1])),
+};
let weight = into_contiguous(permute(weight, &[2, 3, 1, 0]));

let out_shape = Shape::new([padded_batch_size, out_h, out_w, padded_out_channels]);
202 changes: 202 additions & 0 deletions crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs
@@ -0,0 +1,202 @@
use burn_tensor::Shape;
use cubecl::{prelude::*, CubeCount, CubeDim};

use crate::{
ops::{max_vectorization, numeric::empty_device},
tensor::JitTensor,
JitElement, JitRuntime,
};

/// Efficiently transpose an NCHW tensor to NHWC for use in kernels that prefer NHWC for performance.
/// Faster than `into_contiguous`, but specialized only for this specific permutation.
///
/// # Arguments
///
/// * `input` - The input in NCHW format
///
/// # Output
///
/// The input in NHWC format
///
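/// # Example
///
/// A minimal usage sketch (assuming `input_nchw` is a contiguous NCHW `JitTensor`
/// and `R`/`E` are the runtime and element types in scope):
///
/// ```ignore
/// // [N, C, H, W] -> [N, H, W, C], materialized into a freshly allocated buffer.
/// let output_nhwc = nchw_to_nhwc::<R, E>(input_nchw);
/// ```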
pub fn nchw_to_nhwc<R: JitRuntime, E: JitElement>(input: JitTensor<R>) -> JitTensor<R> {
let tiles_per_block = 8;
let warp_size = 32;
let tile_dim = 16;
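// Each plane (warp) moves one `tile_dim` x `tile_dim` tile of the (channel, HW)
// matrix through shared memory; a cube handles up to `tiles_per_block` tiles at once.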

let [batch_size, in_c, h, w] = input.shape.dims();
let hw = h * w;

let out_shape = Shape::new([batch_size, h, w, in_c]);
let out = empty_device::<R, E>(input.client.clone(), input.device.clone(), out_shape);

let tiles_channel = in_c.div_ceil(tile_dim) as u32;
let tiles_hw = hw.div_ceil(tile_dim) as u32;

let block_tiles_y = Ord::min(tiles_channel.next_power_of_two(), tiles_per_block);
let block_tiles_x = Ord::min(tiles_per_block / block_tiles_y, tiles_hw);

let cube_count_y = tiles_channel.div_ceil(block_tiles_y);
let cube_count_x = tiles_hw.div_ceil(block_tiles_x);
let cube_count_z = batch_size as u32;
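// Cube grid: x walks tiles along HW, y walks tiles along channels, z selects the batch.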

let config = ComptimeConfig {
tiles_x: block_tiles_x,
warps_per_cube: tiles_per_block,
tile_dim: tile_dim as u32,
warp_size,
num_banks: 32,
};

let cube_dim = CubeDim {
x: block_tiles_x * warp_size,
y: block_tiles_y,
z: 1,
};
let cube_count = CubeCount::Static(cube_count_x, cube_count_y, cube_count_z);

let in_vec = max_vectorization(&input);
let out_vec = R::supported_line_sizes()
.iter()
.copied()
.find(|vec| in_c % *vec as usize == 0)
.unwrap_or(1);
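// `in_vec` vectorizes reads along HW (the contiguous axis of NCHW); `out_vec`
// vectorizes writes along channels (the contiguous axis of NHWC), so it must
// divide `in_c` evenly.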

unsafe {
nchw_to_nhwc_kernel::launch_unchecked::<E, R>(
&input.client,
cube_count,
cube_dim,
input.as_tensor_arg::<E>(in_vec),
out.as_tensor_arg::<E>(out_vec),
ScalarArg::new(hw as u32),
config,
)
};

out
}

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
struct ComptimeConfig {
tiles_x: u32,
warps_per_cube: u32,
tile_dim: u32,
warp_size: u32,
num_banks: i32,
}

#[cube(launch_unchecked)]
fn nchw_to_nhwc_kernel<E: Numeric>(
input: &Tensor<Line<E>>,
out: &mut Tensor<Line<E>>,
shape_hw: u32,
#[comptime] config: ComptimeConfig,
) {
let ComptimeConfig {
tiles_x,
warps_per_cube,
tile_dim,
warp_size,
num_banks,
} = config;

let tile_elems = tile_dim * tile_dim;

let unit_pos = UNIT_POS;
let intra_warp_unit_idx = unit_pos % 32;
let batch = CUBE_POS_Z;

if batch >= input.shape(0) {
return;
}

let batch_offset = batch * input.stride(0);

let warp_id = plane_broadcast(unit_pos / 32, 0);
let warp_id_x = warp_id / CUBE_DIM_Y;

let tile_x = CUBE_POS_X * tiles_x + warp_id_x;
let tile_y = ABSOLUTE_POS_Y;

let mut shared = SharedMemory::<E>::new(warps_per_cube * tile_elems);
let shared_start = warp_id * tile_elems;

let base_hw = tile_x * tile_dim;
let base_c = tile_y * tile_dim;

let elems_per_unit = tile_elems / warp_size;
let unit_start = intra_warp_unit_idx * elems_per_unit;

let mat_hw_start = unit_start % tile_dim;

let mat_c = unit_start / tile_dim;
let channel = base_c + mat_c;
let offset = channel * input.stride(1) + batch_offset;

let input_vec = input.line_size();
let out_vec = out.line_size();
let in_max = input.buffer_len() - 1;

let channels = input.shape(1);

let mat_offset_base = shared_start + mat_c * tile_dim;
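// The loop below stages the tile in shared memory: each unit reads `elems_per_unit`
// values along HW (contiguous in the NCHW input) and stores them at swizzled indices.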

#[unroll]
for hw in range_stepped(0, elems_per_unit, input_vec) {
let mat_hw = mat_hw_start + hw;
let hw = base_hw + mat_hw;
let offset = Min::min((offset + hw) / input_vec, in_max);
let value = input[offset];

let mat_idx = mat_offset_base + mat_hw;

#[unroll]
for v in 0..input_vec {
let shared_idx = swizzle(mat_idx + v, num_banks);
shared[shared_idx] = value[v];
}
}

sync_units();
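// After the barrier the tile is written back transposed: each unit now steps along
// the channel axis for a fixed HW position, the contiguous axis of the NHWC output.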

let mat_hw = mat_c;
let hw = base_hw + mat_hw;

if hw >= shape_hw {
return;
}

let mat_c_start = mat_hw_start;
let offset = hw * out.stride(2) + batch_offset;
let mat_base = shared_start + mat_hw;

#[unroll]
for ch in range_stepped(0, elems_per_unit, out_vec) {
let mat_c = mat_c_start + ch;
let ch = base_c + mat_c;

let mat_idx = mat_base + mat_c * tile_dim;
let mut value = Line::empty(out_vec);
let offset = (offset + ch) / out_vec;

#[unroll]
for v in 0..out_vec {
let shared_idx = swizzle(mat_idx + v * tile_dim, num_banks);
value[v] = shared[shared_idx];
}

if ch < channels {
out[offset] = value;
}
}
}
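
// `swizzle` XOR-folds higher-order bits of a shared-memory offset into its
// bank-index bits so that the `tile_dim`-strided reads of the write-back pass are
// spread across banks; `bank_count` is assumed to be a power of two.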

#[cube]
pub fn swizzle(offset: u32, #[comptime] bank_count: i32) -> u32 {
let num_bits = comptime!(i32::BITS - bank_count.leading_zeros() - 1);
let bit_mask = (1 << num_bits) - 1;
let yyy_mask = bit_mask << (num_bits);
let mask_shift = num_bits;

offset ^ ((offset & yyy_mask) >> mask_shift)
}
3 changes: 2 additions & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/mod.rs
@@ -3,14 +3,15 @@ mod col2im;
mod direct;
mod im2col;
mod implicit_gemm;
mod layout_swap;
mod transpose_direct;

mod tune;

pub use base::*;
pub use col2im::*;
pub use direct::*;
pub use im2col::*;
pub use implicit_gemm::*;
pub use layout_swap::*;
pub use transpose_direct::*;
pub use tune::*;
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/mod.rs
@@ -10,4 +10,4 @@ pub(crate) use conv_transpose3d::*;
pub(crate) use deform_conv2d::*;
pub(crate) use deform_conv_transpose2d::*;

-pub use conv2d::{conv2d, conv_transpose2d, Conv2dStrategy, ConvTranspose2dStrategy};
+pub use conv2d::{conv2d, conv_transpose2d, nchw_to_nhwc, Conv2dStrategy, ConvTranspose2dStrategy};