Skip to content

Commit

Permalink
Perf/wgpu/reduce dim (tracel-ai#943)
Browse files Browse the repository at this point in the history
* new reduce half working

* surprisingly working

* good on elongated matrix, bad on balanced ones

* working and clean

* autotune not tested, tests fail at non contiguous

* fixed

* autotune tested

* mean dim

* some fixes

* clippy
  • Loading branch information
louisfd authored Nov 13, 2023
1 parent f53ab06 commit 831335a
Show file tree
Hide file tree
Showing 16 changed files with 790 additions and 39 deletions.
4 changes: 4 additions & 0 deletions burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ serial_test = "2.0.0"
[[bench]]
name = "matmul"
harness = false

[[bench]]
name = "reduction"
harness = false
108 changes: 108 additions & 0 deletions burn-wgpu/benches/reduction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_tensor::backend::Backend;
use burn_tensor::{Distribution, Shape, Tensor};
use burn_wgpu::kernel::reduce::{init_reduce_output, sum_dim, sum_dim_shared_memory};
use burn_wgpu::WgpuDevice;
use burn_wgpu::{AutoGraphicsApi, Wgpu};
use derive_new::new;
use std::marker::PhantomData;

use burn_wgpu::GraphicsApi;

/// Shorthand for a wgpu-backed tensor with `f32` floats and `i32` ints.
type WTensor<G, const D: usize> = Tensor<Wgpu<G, f32, i32>, D>;

/// Benchmark fixture for one reduction kernel: reduces a tensor of `shape`
/// along `dim`, `num_repeats` times per sample, on `device`.
#[derive(new)]
struct ReduceBenchmark<B: Backend, F, const D: usize> {
    shape: Shape<D>,
    dim: usize,
    num_repeats: usize,
    device: B::Device,
    // Marker only: selects which reduce implementation `F` to run; stores no data.
    reduce: PhantomData<F>,
}

/// A reduction implementation that can be benchmarked: reduce `input` along `dim`.
trait ReduceFunction<G: GraphicsApi, const D: usize> {
    fn run(input: WTensor<G, D>, dim: usize) -> WTensor<G, D>;
}

impl<F, const D: usize, G> Benchmark for ReduceBenchmark<Wgpu<G, f32, i32>, F, D>
where
    F: ReduceFunction<G, D>,
    G: GraphicsApi,
{
    type Args = WTensor<G, D>;

    /// Human-readable label: reduce function type, input shape and reduced dim.
    fn name(&self) -> String {
        let func = std::any::type_name::<F>();
        format!("{:?} {:?} dim={:?}", func, self.shape.dims, self.dim)
    }

    fn num_samples(&self) -> usize {
        10
    }

    /// Launch the kernel `num_repeats` times; outputs are discarded.
    fn execute(&self, input: Self::Args) {
        (0..self.num_repeats).for_each(|_| {
            F::run(input.clone(), self.dim);
        });
    }

    /// Build a fresh random input tensor on the benchmark device.
    fn prepare(&self) -> Self::Args {
        let shape = self.shape.clone();
        WTensor::random_device(shape, Distribution::Default, &self.device)
    }

    /// Block until all queued GPU work for the device has completed.
    fn sync(&self) {
        Wgpu::<G, f32, i32>::sync(&self.device)
    }
}

/// Defines a zero-sized marker type `$reduce_name` implementing
/// [`ReduceFunction`] via the kernel `$func`, plus a benchmark type alias
/// `$benchmark` specialized to that marker.
macro_rules! bench_reduce {
    ($benchmark:ident, $reduce_name:ident, $func:expr) => {
        struct $reduce_name {}
        impl<G: GraphicsApi, const D: usize> ReduceFunction<G, D> for $reduce_name {
            fn run(input: WTensor<G, D>, dim: usize) -> WTensor<G, D> {
                // Drop to the backend primitive, allocate the reduced-shape
                // output, run the kernel, then wrap the result back up.
                let input = input.into_primitive();
                let output = init_reduce_output(&input, dim);
                Tensor::from_primitive($func(input, output, dim))
            }
        }
        type $benchmark<const D: usize> =
            ReduceBenchmark<Wgpu<AutoGraphicsApi, f32, i32>, $reduce_name, D>;
    };
}

// One benchmark per sum_dim kernel variant under comparison.
bench_reduce!(SumDimBenchmark, SumDim, sum_dim);
bench_reduce!(
    SumDimSharedMemoryBenchmark,
    SumDimSharedMemory,
    sum_dim_shared_memory
);

#[allow(dead_code)]
/// Runs the benchmarks for wgpu reduction implementations
/// (the original comment said "matmul" — copy-paste error).
pub fn bench(device: &WgpuDevice) {
    let num_repeats = 3;
    // Elongated along the reduced axis (dim 1) — see commit notes about
    // elongated vs. balanced shapes.
    let shape = Shape::new([50, 8000, 50]);
    let dim = 1;

    // Local macro: instantiate and run one benchmark with the shared settings.
    macro_rules! run_reduce_benchmark {
        ($benchmark:ident) => {
            run_benchmark($benchmark::new(
                shape.clone(),
                dim,
                num_repeats,
                device.clone(),
            ));
        };
    }

    run_reduce_benchmark!(SumDimSharedMemoryBenchmark);
    run_reduce_benchmark!(SumDimBenchmark);
}

/// Entry point: run the reduction benchmarks on the best available device.
fn main() {
    let device = WgpuDevice::BestAvailable;
    bench(&device)
}
8 changes: 7 additions & 1 deletion burn-wgpu/src/compute/tune_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,25 @@ use std::fmt::Display;

use burn_compute::tune::AutotuneKey;

use crate::kernel::matmul::MatmulAutotuneKey;
use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};

#[derive(Hash, Eq, PartialEq, Debug, Clone)]
/// Key for all autotune-enabled operations
pub enum WgpuAutotuneKey {
    /// Key for matmul operation
    Matmul(MatmulAutotuneKey),
    /// Key for sum_dim operations
    SumDim(ReduceAutotuneKey),
    /// Key for mean_dim operations
    ///
    /// Shares [`ReduceAutotuneKey`] with [`Self::SumDim`]: both are reductions
    /// over one dimension, distinguished by the enum variant itself.
    MeanDim(ReduceAutotuneKey),
}

impl Display for WgpuAutotuneKey {
    /// Delegates to the inner key's `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The match bindings are already references, so call `fmt`
            // directly; the original's `Display::fmt(&key, f)` added a
            // needless extra borrow (`&&T`, clippy::needless_borrow).
            WgpuAutotuneKey::Matmul(matmul_key) => matmul_key.fmt(f),
            WgpuAutotuneKey::SumDim(reduce_key) => reduce_key.fmt(f),
            WgpuAutotuneKey::MeanDim(reduce_key) => reduce_key.fmt(f),
        }
    }
}
Expand Down
4 changes: 2 additions & 2 deletions burn-wgpu/src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ mod clamp;
mod comparison;
mod index;
mod mask;
mod reduction;
mod source;
mod unary;
mod unary_scalar;
Expand All @@ -26,10 +25,11 @@ pub mod matmul;
pub mod pool;
/// Pseudo-random number generator kernels
pub mod prng;
/// Reduction algorithms
pub mod reduce;

pub(crate) use cat::*;
pub(crate) use clamp::*;
pub(crate) use comparison::*;
pub(crate) use index::*;
pub(crate) use mask::*;
pub(crate) use reduction::*;
22 changes: 22 additions & 0 deletions burn-wgpu/src/kernel/reduce/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use crate::{element::WgpuElement, tensor::WgpuTensor};

/// Creates an empty output tensor with reduce output shape
/// Creates an empty output tensor with reduce output shape
///
/// The output shape equals `input`'s shape with `reduce_dim` collapsed to
/// length 1. The backing buffer is allocated but left uninitialized — the
/// reduce kernel is expected to write every element.
pub fn init_reduce_output<E: WgpuElement, const D: usize>(
    input: &WgpuTensor<E, D>,
    reduce_dim: usize,
) -> WgpuTensor<E, D> {
    let mut shape_out = input.shape.clone();
    shape_out.dims[reduce_dim] = 1;

    // Create output handle sized for the reduced shape.
    let num_elems_output = shape_out.num_elements();
    let handle = input
        .client
        .empty(num_elems_output * core::mem::size_of::<E>());

    // `shape_out` is moved into the tensor; the original's trailing
    // `shape_out.clone()` was redundant (clippy::redundant_clone).
    WgpuTensor::new(
        input.client.clone(),
        input.device.clone(),
        shape_out,
        handle,
    )
}
9 changes: 9 additions & 0 deletions burn-wgpu/src/kernel/reduce/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
mod base; // shared helpers (e.g. `init_reduce_output`)
mod reduction; // reduce-dim kernels (sum_dim, mean_dim, args-max/min)
mod reduction_shared_memory; // shared-memory kernel variants (e.g. `sum_dim_shared_memory`)
mod tune; // autotune keys/operations selecting between the variants

// Flatten everything into `kernel::reduce::*` for external use.
pub use base::*;
pub use reduction::*;
pub use reduction_shared_memory::*;
pub use tune::*;
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT};
use crate::{
compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
compute::StaticKernel,
element::WgpuElement,
kernel::{
build_info, elemwise_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
WORKGROUP_DEFAULT,
},
kernel_wgsl,
tensor::WgpuTensor,
};
use burn_tensor::Shape;

kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl");
kernel_wgsl!(ReductionDimRaw, "../template/reduction/reduce_dim.wgsl");
kernel_wgsl!(ReductionArgsRaw, "../template/reduction/args.wgsl");
kernel_wgsl!(
RecursiveSumRaw,
"../../template/reduction/recursive_sum.wgsl"
);
kernel_wgsl!(ReductionDimRaw, "../../template/reduction/reduce_dim.wgsl");
kernel_wgsl!(ReductionArgsRaw, "../../template/reduction/args.wgsl");

pub struct ArgsMax;
pub struct ArgsMin;
pub struct SumDim;
pub struct MeanDim;
pub(crate) struct ArgsMax;
pub(crate) struct ArgsMin;
pub(crate) struct SumDim;
pub(crate) struct MeanDim;

impl StaticKernelSource for SumDim {
fn source() -> SourceTemplate {
Expand Down Expand Up @@ -79,37 +87,29 @@ pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTenso
/// Execute the sum dim kernel.
pub fn sum_dim<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
output: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_dim::<SumDim, E, D>(input, dim)
reduction_dim::<SumDim, E, D>(input, output, dim)
}

/// Execute the mean dim kernel.
pub fn mean_dim<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
output: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_dim::<MeanDim, E, D>(input, dim)
reduction_dim::<MeanDim, E, D>(input, output, dim)
}

fn reduction_dim<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
output: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
let mut shape_out = input.shape.clone();
shape_out.dims[dim] = 1;
let num_elems = shape_out.num_elements();
let handle = input.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(
input.client.clone(),
input.device.clone(),
shape_out,
handle,
);

let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT),
);

let mut info = build_info(&[&input, &output]);
Expand Down Expand Up @@ -174,7 +174,10 @@ fn reduction_args_dim<K: StaticKernelSource, E: WgpuElement, I: WgpuElement, con
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use crate::{
kernel::reduce::init_reduce_output,
tests::{ReferenceBackend, TestBackend},
};
use burn_tensor::{Distribution, Int, Tensor};

#[test]
Expand All @@ -192,10 +195,13 @@ mod tests {
fn reduction_sum_dim_should_work_with_multiple_invocations() {
let tensor = Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let reduce_dim = 1;
let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim);

let val = Tensor::<TestBackend, 2>::from_primitive(reduction_dim::<SumDim, f32, 2>(
tensor.into_primitive(),
1,
output,
reduce_dim,
));
let val_ref = tensor_ref.sum_dim(1);

Expand Down
Loading

0 comments on commit 831335a

Please sign in to comment.