Skip to content

Commit

Permalink
Perf/wgpu/matmul vec4rhs (tracel-ai#914)
Browse files (browse the repository at this point in the history)
  • Loading branch information
louisfd authored Oct 31, 2023
1 parent 96524d4 commit 8742d31
Show file tree
Hide file tree
Showing 18 changed files with 280 additions and 1,022 deletions.
44 changes: 12 additions & 32 deletions burn-wgpu/benches/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@ use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_tensor::backend::Backend;
use burn_tensor::{Distribution, Shape, Tensor};
use burn_wgpu::kernel::matmul::init_matmul_output;
use burn_wgpu::{kernel::matmul::vec4_primitive, WgpuDevice};
use burn_wgpu::kernel::matmul::vec4::matmul_tiling_2d_vec4;
use burn_wgpu::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs;
use burn_wgpu::WgpuDevice;
use burn_wgpu::{AutoGraphicsApi, Wgpu};
use derive_new::new;
use std::marker::PhantomData;

use burn_wgpu::{
kernel::matmul::{
contiguous, contiguous_vectorized, matmul_mem_coalescing_default, matmul_naive_default,
tile, tile_vectorized,
},
kernel::matmul::{matmul_mem_coalescing_default, matmul_naive_default},
GraphicsApi,
};

Expand Down Expand Up @@ -85,37 +84,21 @@ macro_rules! bench_matmul {
MatmulBenchmark<Wgpu<AutoGraphicsApi, f32, i32>, $matmul_name, D>;
};
}

bench_matmul!(NaiveMatmulBenchmark, NaiveMatmul, matmul_naive_default);
bench_matmul!(
MemCoalescingMatmulBenchmark,
MemCoalescingMatmul,
matmul_mem_coalescing_default
);
bench_matmul!(
Tiling2DMatmulContiguousBenchmark,
Tiling2DMatmulContiguous,
contiguous::matmul_tiling_2d_default
);
bench_matmul!(
Tiling2DMatmulTileBenchmark,
Tiling2DMatmulTile,
tile::matmul_tiling_2d_default
);
bench_matmul!(
Tiling2DMatmulTileVectorizedBenchmark,
Tiling2DMatmulTileVectorized,
tile_vectorized::matmul_tiling_2d_default
);
bench_matmul!(
Tiling2DMatmulContiguousVectorizedBenchmark,
Tiling2DMatmulContiguousVectorized,
contiguous_vectorized::matmul_tiling_2d_default
Tiling2DMatmulVec4LHSBenchmark,
Tiling2DMatmulVec4LHS,
matmul_tiling_2d_vec4_lhs
);
bench_matmul!(
Tiling2DMatmulVec4PrimitiveBenchmark,
Tiling2DMatmulVec4Primitive,
vec4_primitive::matmul_tiling_2d_vec4_primitive_default
Tiling2DMatmulVec4Benchmark,
Tiling2DMatmulVec4,
matmul_tiling_2d_vec4
);

#[allow(dead_code)]
Expand All @@ -142,11 +125,8 @@ pub fn bench(device: &WgpuDevice) {
}
run_matmul_benchmark!(NaiveMatmulBenchmark);
run_matmul_benchmark!(MemCoalescingMatmulBenchmark);
run_matmul_benchmark!(Tiling2DMatmulContiguousBenchmark);
run_matmul_benchmark!(Tiling2DMatmulTileBenchmark);
run_matmul_benchmark!(Tiling2DMatmulTileVectorizedBenchmark);
run_matmul_benchmark!(Tiling2DMatmulContiguousVectorizedBenchmark);
run_matmul_benchmark!(Tiling2DMatmulVec4PrimitiveBenchmark);
run_matmul_benchmark!(Tiling2DMatmulVec4LHSBenchmark);
run_matmul_benchmark!(Tiling2DMatmulVec4Benchmark);
}

fn main() {
Expand Down
Loading

0 comments on commit 8742d31

Please sign in to comment.