Implicit GEMM optimizations/bug fixes (#2499)
wingertge authored Nov 18, 2024
1 parent 6d105ea commit f64914b
Showing 8 changed files with 480 additions and 242 deletions.
395 changes: 285 additions & 110 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -117,8 +117,8 @@ bincode = { version = "2.0.0-rc.3", features = [
#
# The following packages disable the "std" feature for no_std compatibility
#
derive-new = { version = "0.7.0", default-features = false }
cfg-if = "1.0.0"
derive-new = { version = "0.7.0", default-features = false }

blas-src = { version = "0.10.0", default-features = false }
half = { version = "2.4.1", features = [
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8f4861ebe577065e2209ee94724c05b514e1b860" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8f4861ebe577065e2209ee94724c05b514e1b860" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
@@ -166,4 +166,4 @@ cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features
tracel-xtask = { version = "~1.1" }

[profile.dev]
debug = 0 # Speed up compilation time and not necessary.
debug = 0 # Speed up compilation time and not necessary.
151 changes: 102 additions & 49 deletions crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
@@ -3,11 +3,16 @@ use burn_tensor::{
Shape,
};
use cmma::{Matrix, MatrixIdent, MatrixLayout};
use cubecl::{cube, prelude::*, Compiler, CubeCount, CubeDim, Feature};
use cubecl::{
cube,
ir::{Elem, FloatKind},
prelude::*,
Compiler, CubeCount, CubeDim, Feature,
};
use half::f16;

use crate::{
kernel::{into_contiguous, slice},
kernel::{into_contiguous, slice, slice_assign},
ops::{
numeric::{empty_device, zeros_device},
permute,
@@ -30,9 +35,17 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
bias: Option<JitTensor<R, F>>,
options: ConvOptions<2>,
) -> JitTensor<R, F> {
let is_tf32 = F::as_elem() == Elem::Float(FloatKind::F32)
&& input
.client
.properties()
.feature_enabled(Feature::Type(Elem::Float(FloatKind::TF32)));

let k_target = if is_tf32 { 8 } else { 16 };

let [batch_size, in_channels, height, width] = input.shape.dims();
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims();
let (pad_in_channels, pad_kh, pad_kw) = padded_k(in_channels, kernel_h, kernel_w);
let (pad_in_channels, pad_kh, pad_kw) = padded_k(in_channels, kernel_h, kernel_w, k_target);
let padded_out_channels = out_channels.div_ceil(16) * 16;

let out_h = calculate_conv_output_size(
@@ -66,12 +79,13 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
"Requirements for implicit GEMM not met:
- CMMA must be available
- `groups` must be 1
- subcube size must be non-variable (might not hold on Intel)
"
);
}

let input = into_contiguous(permute(input, &[0, 2, 3, 1]));
let weight = into_contiguous(permute(weight, &[0, 2, 3, 1]));
let weight = into_contiguous(permute(weight, &[2, 3, 1, 0]));

let out_shape = Shape::new([padded_batch_size, out_h, out_w, padded_out_channels]);
let out = empty_device(input.client.clone(), input.device.clone(), out_shape);
@@ -81,18 +95,19 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
let gemm_n = padded_out_channels as u32;
let gemm_k = (pad_in_channels * pad_kh * pad_kw) as u32;

let slice_size = pad_kh * pad_kw * pad_in_channels;

let (cmma_m, cmma_n, cmma_k) =
find_cmma_size::<R, f16, F>(&input.client, gemm_m, gemm_k, gemm_n).unwrap();
find_cmma_size::<R, F>(&input.client, gemm_m, gemm_k, gemm_n).unwrap();

let slice_size = pad_kh * pad_kw * pad_in_channels;

let cube_dim_x = 128;
let cube_dim_y = Ord::min(gemm_n.div_ceil(16), 2);

let input_tile_size = cmma_m * cmma_k;
let weight_tile_size = cmma_k * cmma_n;

let warp_size = 32;
let topology = input.client.properties().hardware_properties();
let warp_size = topology.plane_size_min;
let warps_per_cube = (cube_dim_y * cube_dim_x) / warp_size;

let supported_vecs = R::supported_line_sizes();
@@ -102,12 +117,19 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(

let weight_elems_per_thread = weight_tile_size / warp_size;
let weight_vectorization =
find_common_vec(in_channels, weight_elems_per_thread, supported_vecs);
find_common_vec(out_channels, weight_elems_per_thread, supported_vecs);

let has_bias = bias.is_some();
let bias = bias.unwrap_or_else(|| {
zeros_device(input.client.clone(), input.device.clone(), Shape::new([1]))
});
let bias = match bias {
Some(bias) if out_channels == padded_out_channels => bias,
Some(bias) => {
let shape = Shape::new([padded_out_channels]);
let padded_bias = zeros_device(bias.client.clone(), bias.device.clone(), shape);
#[allow(clippy::single_range_in_vec_init)]
slice_assign(padded_bias, &[0..out_channels], bias)
}
None => empty_device(input.client.clone(), input.device.clone(), Shape::new([1])),
};

let settings = GemmSettings {
cmma_m,
@@ -138,7 +160,12 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(

let cube_count = CubeCount::Static(cube_count_x, cube_count_y, 1);

implicit_gemm_kernel::launch::<F, f16, R>(
let launch = match is_tf32 {
false => implicit_gemm_kernel::launch::<F, f16, R>,
true => implicit_gemm_kernel::launch::<F, tf32, R>,
};

launch(
&input.client,
cube_count,
cube_dim,
@@ -303,7 +330,7 @@ fn implicit_gemm_kernel<F: Float, FMat: Float>(
let mut out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size);

if conv_settings.aligned || pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n {
execute_gemm(
execute_gemm::<F, FMat>(
input,
weight,
bias,
@@ -396,7 +423,7 @@ fn make_matrices<F: Float, FAcc: Float>(
cmma_m,
cmma_n,
cmma_k,
MatrixLayout::ColMajor,
MatrixLayout::RowMajor,
)
},
acc,
Expand All @@ -422,8 +449,7 @@ fn execute_gemm<F: Float, FMat: Float>(

let matrices = make_matrices::<FMat, F>(g_settings, has_bias);
if has_bias {
let n = UNIT_POS_Y * cmma_n + pos.global_n;
let bias_tile = bias.slice(n, n + cmma_n);
let bias_tile = bias.slice(pos.global_n, pos.global_n + cmma_n);
cmma::load_with_layout(&matrices.acc, &bias_tile, 0, MatrixLayout::RowMajor);
}

@@ -440,8 +466,8 @@ fn execute_gemm<F: Float, FMat: Float>(
load_weight_tile(weight, weight_tile, dims, pos, k, g_settings, k_settings);

// Run CMMA
cmma::load(&matrices.b, &weight_tile.to_slice(), cmma_n);
cmma::load(&matrices.a, &input_tile.to_slice(), cmma_k);
cmma::load(&matrices.b, &weight_tile.to_slice(), cmma_k);

cmma::execute::<FMat, FMat, F, F>(&matrices.a, &matrices.b, &matrices.acc, &matrices.acc);
}
@@ -573,29 +599,31 @@ fn load_weight_tile<F: Float, FMat: Float>(
let cmma_filter_tile_size = cmma_k * cmma_n;
let elems_per_thread = cmma_filter_tile_size / warp_size;
let start = pos.intra_warp_unit_idx * elems_per_thread;
let abs_slice_col = pos.global_n + (start / cmma_k); // Row of the matrix the slice is on

let n_in_bounds = !check_n || abs_slice_col < weight.shape(0);
let col_idx = abs_slice_col * weight.stride(0);
let global_k = start / cmma_n + k;

let (k_idx, k_in_bounds) = if check_k {
let channel = global_k % dims.pad_channels;
let kernel_x = global_k / dims.pad_channels % dims.pad_kw;
let kernel_y = global_k / (dims.pad_channels * dims.pad_kw);
let k_in_bounds =
!check_k || (channel < weight.shape(2) && kernel_x < kernel_w && kernel_y < kernel_h);
let idx =
kernel_y * weight.stride(0) + kernel_x * weight.stride(1) + channel * weight.stride(2);
(idx, k_in_bounds)
} else {
(global_k * weight.stride(2), true)
};

#[unroll]
for n in range_stepped(0, elems_per_thread, vec) {
let n = n + start;
// Compute where in the slice we are starting
let rel_slice_row = n % cmma_k; // Relative row (0 - 15)
let abs_slice_row = k + rel_slice_row; // Row of the matrix the slice is on

let (idx, k_in_bounds) = if check_k {
let channel = abs_slice_row % dims.pad_channels;
let kernel_x = abs_slice_row / dims.pad_channels % dims.pad_kw;
let kernel_y = abs_slice_row / (dims.pad_channels * dims.pad_kw);
let k_in_bounds = !check_k
|| (channel < weight.shape(3) && kernel_x < kernel_w && kernel_y < kernel_h);
let idx = col_idx + kernel_y * weight.stride(1) + kernel_x * weight.stride(2) + channel;
(idx, k_in_bounds)
} else {
(col_idx + abs_slice_row, true)
};

let global_n = (n % cmma_n) + pos.global_n;
let n_in_bounds = !check_n || global_n < weight.shape(3);

let idx = k_idx + global_n;

let value = FMat::cast_from(weight[idx / vec]);
let value = select(k_in_bounds && n_in_bounds, value, FMat::new(0.0));

@@ -617,29 +645,46 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
out_w: usize,
client: &ComputeClient<R::Server, R::Channel>,
) -> bool {
let (in_channels, kernel_h, kernel_w) = padded_k(in_channels, kernel_size[0], kernel_size[1]);
let cmma_k = match (
E::as_elem(),
client
.properties()
.feature_enabled(Feature::Type(tf32::as_elem())),
) {
(Elem::Float(FloatKind::F32), true) => 8,
_ => 16,
};

let (in_channels, kernel_h, kernel_w) =
padded_k(in_channels, kernel_size[0], kernel_size[1], cmma_k);
let batch_size = padded_batch_size(batch_size, out_h, out_w);
let out_channels = out_channels.div_ceil(16) * 16;

let gemm_m = batch_size * out_h * out_w;
let gemm_n = out_channels;
let gemm_k = in_channels * kernel_h * kernel_w;

let size = find_cmma_size::<R, f16, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32);
let size = find_cmma_size::<R, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32);

if let Some((cmma_m, cmma_k, cmma_n)) = size {
let warps_per_cube = 8;

let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
let topology = client.properties().hardware_properties();
let not_intel = topology.plane_size_min >= 32;

<R::Compiler as Compiler>::max_shared_memory_size() >= smem_size && groups == 1
<R::Compiler as Compiler>::max_shared_memory_size() >= smem_size && groups == 1 && not_intel
} else {
false
}
}

fn padded_k(in_channels: usize, kernel_h: usize, kernel_w: usize) -> (usize, usize, usize) {
let target = 16;
fn padded_k(
in_channels: usize,
kernel_h: usize,
kernel_w: usize,
target: usize,
) -> (usize, usize, usize) {
if in_channels * kernel_h * kernel_w % target == 0 {
return (in_channels, kernel_h, kernel_w);
}
@@ -659,41 +704,49 @@ fn padded_k(in_channels: usize, kernel_h: usize, kernel_w: usize) -> (usize, usi

fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize {
let out_size = out_h * out_w;
let target = if out_size % 2 == 0 {
let target = if out_size.is_power_of_two() || out_size % 16 == 0 {
(16usize).div_ceil(out_size)
} else {
16
};
batch_size.div_ceil(target) * target
}

fn find_cmma_size<R: JitRuntime, F: Float, FAcc: Float>(
fn find_cmma_size<R: JitRuntime, F: Float>(
client: &ComputeClient<R::Server, R::Channel>,
gemm_m: u32,
gemm_k: u32,
gemm_n: u32,
) -> Option<(u32, u32, u32)> {
supported_cmma_sizes::<R, F, FAcc>(client)
supported_cmma_sizes::<R, F>(client)
.into_iter()
.find(|(m, k, n)| {
gemm_m % *m as u32 == 0 && gemm_k % *k as u32 == 0 && gemm_n % *n as u32 == 0
})
.map(|(m, k, n)| (m as u32, n as u32, k as u32))
}

fn supported_cmma_sizes<R: JitRuntime, F: Float, FAcc: Float>(
fn supported_cmma_sizes<R: JitRuntime, F: Float>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Vec<(u8, u8, u8)> {
let requested_sizes = [(16, 16, 16), (32, 16, 8), (8, 16, 32)];
let (requested_sizes, matrix_elem) = match (
F::as_elem(),
client
.properties()
.feature_enabled(Feature::Type(tf32::as_elem())),
) {
(Elem::Float(FloatKind::F32), true) => (vec![(16, 8, 16)], tf32::as_elem()),
_ => (vec![(16, 16, 16), (32, 16, 8), (8, 16, 32)], f16::as_elem()),
};

requested_sizes
.iter()
.copied()
.filter(|(m, k, n)| {
client.properties().feature_enabled(Feature::Cmma {
a: F::as_elem(),
b: F::as_elem(),
c: FAcc::as_elem(),
a: matrix_elem,
b: matrix_elem,
c: F::as_elem(),
m: *m,
k: *k,
n: *n,
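Note: the sizing logic in this file can be summarized with a small standalone sketch (plain Rust, independent of burn/CubeCL). It mirrors two things visible in the diff above: the CMMA k target is 8 for f32 elements when TF32 is available and 16 otherwise, and padded_k returns its inputs unchanged when in_channels * kernel_h * kernel_w is already a multiple of that target. The padding fallback below (rounding in_channels up to the next multiple) is an illustrative assumption, since the real helper's fallback branch is not shown in this hunk.

fn k_target(elem_is_f32: bool, tf32_supported: bool) -> usize {
    // 8-wide k tiles only apply to f32 inputs backed by TF32 CMMA;
    // everything else falls back to the f16 path with k = 16.
    if elem_is_f32 && tf32_supported { 8 } else { 16 }
}

// Simplified stand-in for `padded_k`: leave the shape alone when the
// flattened K dimension already divides the target, otherwise pad the
// channel count up to the next multiple (assumption for illustration only).
fn padded_k_sketch(
    in_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    target: usize,
) -> (usize, usize, usize) {
    if (in_channels * kernel_h * kernel_w) % target == 0 {
        return (in_channels, kernel_h, kernel_w);
    }
    (in_channels.div_ceil(target) * target, kernel_h, kernel_w)
}

fn main() {
    assert_eq!(k_target(true, true), 8);   // f32 + TF32 -> k = 8
    assert_eq!(k_target(true, false), 16); // f32 without TF32 -> f16 path
    // 4 channels, 3x3 kernel: 36 is not a multiple of 8, so pad channels to 8.
    assert_eq!(padded_k_sketch(4, 3, 3, 8), (8, 3, 3));
    // 16 channels, 3x3 kernel: 144 is already a multiple of 16, unchanged.
    assert_eq!(padded_k_sketch(16, 3, 3, 16), (16, 3, 3));
}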
18 changes: 8 additions & 10 deletions crates/burn-jit/src/kernel/interpolate/bicubic.rs
@@ -2,7 +2,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*};

use crate::{tensor::JitTensor, FloatElement, JitRuntime};

#[cube(launch_unchecked)]
#[cube(launch)]
fn interpolate_bicubic_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
@@ -128,15 +128,13 @@ pub(crate) fn interpolate_bicubic_launch<R: JitRuntime, E: FloatElement>(
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);

unsafe {
interpolate_bicubic_kernel::launch_unchecked::<E, R>(
&input.client,
cube_count,
cube_dim,
input.as_tensor_arg(1),
output.as_tensor_arg(1),
)
};
interpolate_bicubic_kernel::launch::<E, R>(
&input.client,
cube_count,
cube_dim,
input.as_tensor_arg(1),
output.as_tensor_arg(1),
);

output
}
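The switch above from launch_unchecked (inside an unsafe block) to the checked launch is possible because the kernel body already returns early for out-of-range invocations (if ABSOLUTE_POS >= output.len()). A minimal plain-Rust sketch of that guard pattern, purely illustrative and not the CubeCL API:

fn kernel_body(absolute_pos: usize, output: &mut [f32], value: f32) {
    // Out-of-range invocations are no-ops, so launching more "threads"
    // than there are output elements is harmless.
    if absolute_pos >= output.len() {
        return;
    }
    output[absolute_pos] = value;
}

fn main() {
    let mut out = vec![0.0f32; 4];
    // Simulate a launch with more invocations than elements.
    for pos in 0..8 {
        kernel_body(pos, &mut out, 1.0);
    }
    assert_eq!(out, vec![1.0f32; 4]);
}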
