
Commit

[Fix] Prevent various OOB accesses and discontiguous buffer bugs (tracel-ai#2467)

* Fix various OOB accesses and discontiguous buffer bugs

* Fix docs

* Fix typo in fusion

* Re-enable implicit GEMM

* Revert last commit

* Fix implicit GEMM and reenable it

* Optimize bias loading
wingertge authored Nov 11, 2024
1 parent 8f6535a commit b4fa1fc
Showing 11 changed files with 74 additions and 98 deletions.
6 changes: 3 additions & 3 deletions crates/burn-fusion/src/ops/int.rs
@@ -1721,9 +1721,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
     fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
-        let input = handles.get_bool_tensor::<B>(&self.desc.input);
-        let output = B::bool_expand(input, self.desc.shape.into());
-        handles.register_bool_tensor::<B>(&self.desc.out.id, output);
+        let input = handles.get_int_tensor::<B>(&self.desc.input);
+        let output = B::int_expand(input, self.desc.shape.into());
+        handles.register_int_tensor::<B>(&self.desc.out.id, output);
     }
 }

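Note on the fix above: the fusion ExpandOps for int tensors had been copy-pasted from the bool version, so it fetched and registered its handles through the bool accessors. A toy model in plain Rust of why type-tagged handles make that a real bug (hypothetical types, not burn's actual HandleContainer):

#[derive(Debug)]
enum Handle {
    Int(Vec<i64>),
    Bool(Vec<bool>),
}

// Fetching through the wrong accessor fails loudly in this toy model;
// in a real backend it can silently misinterpret the buffer instead.
fn get_int(handle: &Handle) -> &[i64] {
    match handle {
        Handle::Int(values) => values,
        other => panic!("expected an int handle, got {other:?}"),
    }
}

fn main() {
    let handle = Handle::Int(vec![1, 2, 3]);
    // The corrected path: int expand fetches with the int accessor.
    assert_eq!(get_int(&handle), &[1, 2, 3]);
}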
25 changes: 12 additions & 13 deletions crates/burn-jit/src/kernel/binary.rs
@@ -1,10 +1,12 @@
-use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
+use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
 use burn_tensor::Shape;
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
     tensor_vectorization_factor,
 };
 
+use super::into_contiguous;
+
 #[cube]
 pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
     /// Execute a binary operation.
@@ -66,9 +68,7 @@ pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOp<C>>(
     scalar: C,
     output: &mut Tensor<Line<C>>,
 ) {
-    let offset_output = ABSOLUTE_POS;
-
-    if offset_output >= output.len() {
+    if ABSOLUTE_POS >= output.len() {
         return;
     }
 
@@ -176,9 +176,7 @@ pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(

         rhs
     } else {
-        let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
-        let output =
-            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
+        let output = empty_device::<R, E>(lhs.client.clone(), lhs.device.clone(), shape_out);
         let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
         let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;
 
@@ -199,9 +197,13 @@ pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
 }
 
 pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
-    tensor: JitTensor<R, E>,
+    mut tensor: JitTensor<R, E>,
     scalar: E,
 ) -> JitTensor<R, E> {
+    if !tensor.is_contiguous_buffer() {
+        tensor = into_contiguous(tensor);
+    }
+
     // Vectorization is only enabled when the last dimension is contiguous.
     let ndims = tensor.shape.num_dims();
     let vectorization_factor =
@@ -225,13 +227,10 @@ pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(

         tensor
     } else {
-        let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
-        let output = JitTensor::new(
+        let output = empty_device(
             tensor.client.clone(),
-            buffer,
-            tensor.shape.clone(),
             tensor.device.clone(),
-            tensor.strides.clone(),
+            tensor.shape.clone(),
         );
 
         kernel_scalar_binop::launch::<E, O, R>(
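Note on the two changes repeated throughout this file (and in comparison.rs below): output allocation goes through empty_device, which guarantees a fresh dense buffer, and scalar ops now call into_contiguous up front instead of cloning the input's possibly strided layout onto the output. A minimal sketch in plain Rust of the property is_contiguous_buffer is assumed to test: strides must be exactly the dense row-major strides of the shape, and only then is plain linear (and vectorized) indexing of the underlying buffer valid.

fn is_contiguous_buffer(shape: &[usize], strides: &[usize]) -> bool {
    let mut expected = 1;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if stride != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

fn main() {
    // Dense row-major 2x3 tensor: strides [3, 1].
    assert!(is_contiguous_buffer(&[2, 3], &[3, 1]));
    // Its transpose shares the same buffer with strides [1, 3]: linear
    // indexing walks elements in the wrong order, and vectorized loads
    // can run past the end. That is the case into_contiguous repacks.
    assert!(!is_contiguous_buffer(&[3, 2], &[1, 3]));
}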
21 changes: 11 additions & 10 deletions crates/burn-jit/src/kernel/comparison.rs
@@ -1,10 +1,12 @@
-use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
+use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
 use burn_tensor::Shape;
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
     tensor_vectorization_factor,
 };
 
+use super::into_contiguous;
+
 #[cube]
 pub(crate) trait ComparisonOp<C: Numeric>: 'static + Send + Sync {
     /// Execute a comparison operation.
@@ -169,9 +171,7 @@ pub(crate) fn launch_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>>(

         JitTensor::new(rhs.client, rhs.handle, rhs.shape, rhs.device, rhs.strides)
     } else {
-        let buffer = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
-        let output =
-            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
+        let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
         let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
         let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;
 
@@ -192,9 +192,13 @@ pub(crate) fn launch_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>>(
 }
 
 pub(crate) fn launch_scalar_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>>(
-    tensor: JitTensor<R, E>,
+    mut tensor: JitTensor<R, E>,
     scalar: E,
 ) -> JitTensor<R, u32> {
+    if !tensor.is_contiguous_buffer() {
+        tensor = into_contiguous(tensor);
+    }
+
     let ndims = tensor.shape.num_dims();
     // Vectorization is only enabled when the last dimension is contiguous.
     let vectorization_factor =
@@ -225,13 +229,10 @@ pub(crate) fn launch_scalar_cmp<R: JitRuntime, E: JitElement, O: ComparisonOp<E>
             tensor.strides,
         )
     } else {
-        let buffer = tensor.client.empty(num_elems * core::mem::size_of::<u32>());
-        let output = JitTensor::new(
+        let output = empty_device(
             tensor.client.clone(),
-            buffer,
-            tensor.shape.clone(),
             tensor.device.clone(),
-            tensor.strides.clone(),
+            tensor.shape.clone(),
         );
 
         kernel_scalar_cmp::launch::<E, O, R>(
4 changes: 1 addition & 3 deletions crates/burn-jit/src/kernel/conv/conv2d/col2im.rs
@@ -217,12 +217,10 @@ fn col2im_kernel<F: Float>(
     args: &Col2ImArgs,
     #[comptime] has_bias: bool,
 ) {
-    if ABSOLUTE_POS > image.len() {
+    if ABSOLUTE_POS >= image.len() {
         return;
     }
 
-    let _ = bias[0]; // Keep in bind group
-
     let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w;
     let im_y = ABSOLUTE_POS / image.stride(2) % image.shape(2) + args.pad_h;
     let ch_im = ABSOLUTE_POS / image.stride(1) % image.shape(1);
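Two fixes here: the bounds guard was off by one (> instead of >=), and the unconditional bias[0] read, a workaround that kept the bias buffer in the bind group, is gone. A minimal sketch of the off-by-one in plain Rust:

fn guarded_read(data: &[f32], idx: usize) -> Option<f32> {
    // Correct guard: valid indices are 0..len, so idx == len must be rejected.
    // The old `idx > data.len()` let idx == len through, one element past the end.
    if idx >= data.len() {
        return None;
    }
    Some(data[idx])
}

fn main() {
    let image = vec![0.0f32; 4];
    assert_eq!(guarded_read(&image, 3), Some(0.0)); // last valid index
    assert_eq!(guarded_read(&image, 4), None);      // was let through before the fix
}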
64 changes: 8 additions & 56 deletions crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
@@ -91,7 +91,6 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(

     let input_tile_size = cmma_m * cmma_k;
     let weight_tile_size = cmma_k * cmma_n;
-    let acc_tile_size = cmma_m * cmma_n;
 
     let warp_size = 32;
     let warps_per_cube = (cube_dim_y * cube_dim_x) / warp_size;
@@ -104,8 +103,6 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
     let weight_elems_per_thread = weight_tile_size / warp_size;
     let weight_vectorization =
         find_common_vec(in_channels, weight_elems_per_thread, supported_vecs);
-    let bias_elems_per_thread = acc_tile_size / warp_size;
-    let bias_vectorization = find_common_vec(out_channels, bias_elems_per_thread, supported_vecs);
 
     let has_bias = bias.is_some();
     let bias = bias.unwrap_or_else(|| {
@@ -147,7 +144,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
         cube_dim,
         input.as_tensor_arg(input_vectorization),
         weight.as_tensor_arg(weight_vectorization),
-        bias.as_tensor_arg(bias_vectorization),
+        bias.as_tensor_arg(1),
         out.as_tensor_arg(1),
         DimensionsLaunch::new(
             ScalarArg::new(gemm_m),
@@ -262,11 +259,11 @@ struct Matrices<F: Float, FAcc: Float> {
 }
 
 #[allow(clippy::collapsible_else_if)]
-#[cube(launch_unchecked, launch)]
+#[cube(launch)]
 fn implicit_gemm_kernel<F: Float, FMat: Float>(
     input: &Tensor<Line<F>>,
     weight: &Tensor<Line<F>>,
-    bias: &Tensor<Line<F>>,
+    bias: &Tensor<F>,
     out: &mut Tensor<F>,
     dims: &Dimensions,
     args: &ConvArgs,
@@ -410,7 +407,7 @@ fn make_matrices<F: Float, FAcc: Float>(
 fn execute_gemm<F: Float, FMat: Float>(
     input: &Tensor<Line<F>>,
     weight: &Tensor<Line<F>>,
-    bias: &Tensor<Line<F>>,
+    bias: &Tensor<F>,
     out: &mut SliceMut<F>,
     input_tile: &mut SliceMut<FMat>,
     weight_tile: &mut SliceMut<FMat>,
@@ -420,25 +417,14 @@ fn execute_gemm<F: Float, FMat: Float>(
     #[comptime] g_settings: GemmSettings,
     #[comptime] k_settings: ConvSettings,
 ) {
-    let GemmSettings {
-        cmma_m,
-        cmma_n,
-        cmma_k,
-        warps_per_cube,
-        ..
-    } = g_settings;
+    let GemmSettings { cmma_n, cmma_k, .. } = g_settings;
     let has_bias = k_settings.has_bias;
 
     let matrices = make_matrices::<FMat, F>(g_settings, has_bias);
     if has_bias {
-        let mut smem_bias = SharedMemory::new(cmma_m * cmma_n * warps_per_cube);
-        load_bias_tile(bias, &mut smem_bias, pos, g_settings);
-        cmma::load_with_layout(
-            &matrices.acc,
-            smem_bias.as_slice(),
-            cmma_n,
-            MatrixLayout::RowMajor,
-        );
+        let n = UNIT_POS_Y * cmma_n + pos.global_n;
+        let bias_tile = bias.slice(n, n + cmma_n);
+        cmma::load_with_layout(&matrices.acc, bias_tile, 0, MatrixLayout::RowMajor);
     }
 
     // Loop over the K-dimension
@@ -620,40 +606,6 @@ fn load_weight_tile<F: Float, FMat: Float>(
     }
 }
 
-#[cube]
-fn load_bias_tile<F: Float>(
-    bias: &Tensor<Line<F>>,
-    tile: &mut SharedMemory<F>,
-    pos: &Positions,
-    #[comptime] gemm_settings: GemmSettings,
-) {
-    let GemmSettings {
-        cmma_n,
-        cmma_m,
-        warp_size,
-        ..
-    } = gemm_settings;
-
-    let vec = vectorization_of(bias);
-    let cmma_acc_tile_size = cmma_m * cmma_n;
-    let elems_per_thread = cmma_acc_tile_size / warp_size;
-    let start = pos.intra_warp_unit_idx * elems_per_thread;
-    let bias_tile_start = pos.cube_linear_warp_idx * cmma_acc_tile_size;
-
-    #[unroll]
-    for n in range_stepped(0, elems_per_thread, vec) {
-        let n = n + start;
-
-        let row = n % cmma_n + pos.global_n;
-        let value = bias[row / vec];
-
-        #[unroll]
-        for i in 0..vec {
-            tile[bias_tile_start + n + i] = value[i];
-        }
-    }
-}
-
 #[allow(clippy::too_many_arguments)]
 pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
     batch_size: usize,
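This is the bias-loading optimization named in the commit message: instead of each warp staging an m x n bias tile in shared memory via load_bias_tile (now deleted), the accumulator is loaded straight from a global-memory slice of the bias with a stride of 0, which (assuming the third argument of cmma::load_with_layout is the row stride) broadcasts the single length-n bias row across every row of the tile. A plain-Rust sketch of the effect:

fn broadcast_bias_into_acc(acc: &mut [f32], bias_tile: &[f32], m: usize, n: usize) {
    assert_eq!(acc.len(), m * n);
    assert_eq!(bias_tile.len(), n);
    for row in acc.chunks_exact_mut(n) {
        // Row stride 0: every accumulator row reads the same bias row.
        row.copy_from_slice(bias_tile);
    }
}

fn main() {
    let (m, n) = (2, 4);
    let mut acc = vec![0.0f32; m * n];
    let bias = [1.0f32, 2.0, 3.0, 4.0];
    broadcast_bias_into_acc(&mut acc, &bias, m, n);
    assert_eq!(&acc[..n], &bias[..]); // row 0
    assert_eq!(&acc[n..], &bias[..]); // row 1
}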
7 changes: 5 additions & 2 deletions crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs
@@ -9,7 +9,10 @@ use cubecl::{

 use crate::{
     kernel::{
-        conv::{batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col},
+        conv::{
+            batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col,
+            conv2d_implicit_gemm,
+        },
         prng::random_uniform,
     },
     tensor::JitTensor,
@@ -39,7 +42,7 @@ pub fn conv2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
 }
 
 #[tune(
-    operations(conv2d_direct, conv2d_im2col),
+    operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm),
     create_key = create_key,
     should_run = should_run
 )]
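With the kernel fixed, conv2d_implicit_gemm rejoins the autotune candidate set. Schematically, with hypothetical stand-ins rather than the real cubecl #[tune] macro, the tuner times every candidate for a given problem key and caches the fastest:

use std::time::{Duration, Instant};

// Returns the index of the fastest candidate.
fn fastest(candidates: &[(&str, fn())]) -> usize {
    let mut best = (0, Duration::MAX);
    for (i, (_name, run)) in candidates.iter().enumerate() {
        let start = Instant::now();
        run();
        let elapsed = start.elapsed();
        if elapsed < best.1 {
            best = (i, elapsed);
        }
    }
    best.0
}

fn conv2d_direct() {}        // stand-in
fn conv2d_im2col() {}        // stand-in
fn conv2d_implicit_gemm() {} // stand-in

fn main() {
    let ops: [(&str, fn()); 3] = [
        ("direct", conv2d_direct),
        ("im2col", conv2d_im2col),
        ("implicit_gemm", conv2d_implicit_gemm),
    ];
    println!("fastest candidate: {}", ops[fastest(&ops)].0);
}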
11 changes: 9 additions & 2 deletions crates/burn-jit/src/kernel/reduce/shared/kernel.rs
@@ -16,12 +16,17 @@ pub fn reduce_dim_shared_kernel<
     #[comptime] smem_size: u32,
     #[comptime] elems_per_thread: u32,
     #[comptime] divisible_shape: bool,
+    #[comptime] check_out: bool,
 ) {
+    let reduce_group_id = CUBE_POS;
+
+    if check_out && reduce_group_id >= output.len() {
+        return;
+    }
+
     let stride_reduce_dim_input = input.stride(dim);
     let shape_reduce_dim_input = input.shape(dim);
 
-    let reduce_group_id = CUBE_POS;
-
     let mut shared_memory = RD::initialize_shared(smem_size, UNIT_POS);
 
     let mut index_offset = 0;
@@ -100,6 +105,7 @@ pub fn reduce_dim_shared<
         f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32;
 
     let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32;
+    let check_out = (cube_count_x * cube_count_y) as usize != num_elems_output;
 
     unsafe {
         reduce_dim_shared_kernel::launch_unchecked::<RD, EI, EO, R>(
@@ -112,6 +118,7 @@
             cube_dim.num_elems(),
             elems_per_thread,
             divisible_shape,
+            check_out,
         )
     };
 
10 changes: 9 additions & 1 deletion crates/burn-jit/src/kernel/reduce/subcube/kernel.rs
@@ -20,13 +20,19 @@ pub fn reduce_dim_subcube_kernel<
     #[comptime] smem_size: u32,
     #[comptime] elems_per_thread: u32,
     #[comptime] divisible_shape: bool,
+    #[comptime] check_out: bool,
 ) {
+    let reduce_group_id = CUBE_POS;
+
+    if check_out && reduce_group_id >= output.len() {
+        return;
+    }
+
     let stride_reduce_dim_input = input.stride(dim);
     let shape_reduce_dim_input = input.shape(dim);
 
     let should_unroll = elems_per_thread <= 8;
 
-    let reduce_group_id = CUBE_POS;
     let warp_id = UNIT_POS / SUBCUBE_DIM;
 
     let mut shared_memory = RD::init_shared(smem_size);
@@ -112,6 +118,7 @@ pub fn reduce_dim_subcube<
         f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32;
 
     let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32;
+    let check_out = (cube_count_x * cube_count_y) as usize != num_elems_output;
     let smem_size = cube_dim.num_elems() / warp_size;
 
     unsafe {
@@ -125,6 +132,7 @@
             smem_size,
             elems_per_thread,
             divisible_shape,
+            check_out,
         )
     };
 
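The same guard lands in both the shared-memory and subcube reduction kernels. The dispatch grid is a whole rectangle of cubes, so when the launcher rounds the number of reduce groups up to fill it, trailing cubes have no output element and must exit before writing. A minimal sketch with made-up numbers:

fn main() {
    let num_elems_output = 5usize;
    // Assumption for illustration: the launcher rounds 5 outputs up to a 2x3 grid.
    let (cube_count_x, cube_count_y) = (2usize, 3usize);

    // Host side: compile the check in only when the grid overshoots.
    let check_out = cube_count_x * cube_count_y != num_elems_output;
    assert!(check_out);

    // Device side, one iteration per cube: CUBE_POS plays reduce_group_id.
    for cube_pos in 0..cube_count_x * cube_count_y {
        if check_out && cube_pos >= num_elems_output {
            continue; // cube 5 would write output[5], out of bounds
        }
        // ... reduce this group and write output[cube_pos] ...
    }
}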
8 changes: 3 additions & 5 deletions crates/burn-jit/src/kernel/unary.rs
@@ -1,4 +1,4 @@
-use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
+use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
 use cubecl::{
     calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
     tensor_vectorization_factor, unexpanded,
@@ -66,7 +66,7 @@ where
         calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
     let is_contiguous = tensor.is_contiguous();
 
-    if tensor.can_mut() && is_contiguous {
+    if tensor.can_mut() && tensor.is_contiguous_buffer() {
         unary_kernel::launch::<E, O, R>(
             &client,
             cube_count,
@@ -80,12 +80,10 @@ where

         tensor
     } else {
-        let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
-        let output = JitTensor::new_contiguous(
+        let output = empty_device(
             tensor.client.clone(),
             tensor.device.clone(),
             tensor.shape.clone(),
-            buffer,
         );
 
         unary_kernel::launch::<E, O, R>(