Skip to content

Commit

Permalink
[Optimization] Implicit gemm rewrite (#2545)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Nov 29, 2024
1 parent 42e7c1f commit a5624c1
Show file tree
Hide file tree
Showing 41 changed files with 1,830 additions and 141 deletions.
26 changes: 14 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
hip-jit = ["burn/hip-jit"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
hip-jit = ["burn/hip-jit"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
Expand Down
147 changes: 138 additions & 9 deletions backend-comparison/benches/conv2d.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::hint::black_box;

use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct Conv2dBenchmark<B: Backend> {
suffix: &'static str,
input_shape: Shape,
weight_shape: Shape,
bias_shape: Shape,
Expand All @@ -16,7 +19,7 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
type Args = (Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 1>);

fn name(&self) -> String {
"conv2d".into()
format!("conv2d-{}", self.suffix)
}

fn shapes(&self) -> Vec<Vec<usize>> {
Expand Down Expand Up @@ -50,6 +53,10 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
fn sync(&self) {
B::sync(&self.device)
}

fn num_samples(&self) -> usize {
40
}
}

#[allow(dead_code)]
Expand All @@ -75,6 +82,7 @@ fn bench<B: Backend>(
let groups = 1;
let options = ConvOptions::new(strides, padding, dilations, groups);
let benchmark = Conv2dBenchmark::<B> {
suffix: "input_16x512x512_weight_16x3x3_stride_1",
input_shape: [batch_size, channels_in, height_in, width_in].into(),
weight_shape: [
channels_out,
Expand All @@ -88,14 +96,135 @@ fn bench<B: Backend>(
device: device.clone(),
};

save::<B>(
vec![run_benchmark(benchmark)],
device,
feature_name,
url,
token,
)
.unwrap();
let conv1 = Conv2dBenchmark::<B> {
suffix: "input_3x227x227_weight_96x11x11_stride_4",
input_shape: [batch_size, 3, 227, 227].into(),
weight_shape: [96, 3, 11, 11].into(),
bias_shape: [96].into(),
options: ConvOptions::new([4, 4], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv2 = Conv2dBenchmark::<B> {
suffix: "input_3x231x231_weight_96x11x11_stride_4",
input_shape: [batch_size, 3, 231, 231].into(),
weight_shape: [96, 3, 11, 11].into(),
bias_shape: [96].into(),
options: ConvOptions::new([4, 4], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv3 = Conv2dBenchmark::<B> {
suffix: "input_3x227x227_weight_64x7x7_stride_2",
input_shape: [batch_size, 3, 227, 227].into(),
weight_shape: [64, 3, 7, 7].into(),
bias_shape: [64].into(),
options: ConvOptions::new([2, 2], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv4 = Conv2dBenchmark::<B> {
suffix: "input_64x224x224_weight_64x7x7_stride_2",
input_shape: [batch_size, 64, 224, 224].into(),
weight_shape: [64, 64, 7, 7].into(),
bias_shape: [64].into(),
options: ConvOptions::new([2, 2], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv5 = Conv2dBenchmark::<B> {
suffix: "input_96x24x24_weight_256x5x5_stride_1",
input_shape: [batch_size, 96, 24, 24].into(),
weight_shape: [256, 96, 5, 5].into(),
bias_shape: [256].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv6 = Conv2dBenchmark::<B> {
suffix: "input_256x12x12_weight_512x3x3_stride_1",
input_shape: [batch_size, 256, 12, 12].into(),
weight_shape: [512, 256, 3, 3].into(),
bias_shape: [512].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv7 = Conv2dBenchmark::<B> {
suffix: "input_3x224x224_weight_64x3x3_stride_1",
input_shape: [batch_size, 3, 224, 224].into(),
weight_shape: [64, 3, 3, 3].into(),
bias_shape: [64].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv8 = Conv2dBenchmark::<B> {
suffix: "input_64x112x112_weight_128x3x3_stride_1",
input_shape: [batch_size, 64, 112, 112].into(),
weight_shape: [128, 64, 3, 3].into(),
bias_shape: [128].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv9 = Conv2dBenchmark::<B> {
suffix: "input_64x56x56_weight_64x3x3_stride_1",
input_shape: [batch_size, 64, 56, 56].into(),
weight_shape: [64, 64, 3, 3].into(),
bias_shape: [64].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv10 = Conv2dBenchmark::<B> {
suffix: "input_128x28x28_weight_128x3x3_stride_1",
input_shape: [batch_size, 128, 28, 28].into(),
weight_shape: [128, 128, 3, 3].into(),
bias_shape: [128].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv11 = Conv2dBenchmark::<B> {
suffix: "input_256x14x14_weight_256x3x3_stride_1",
input_shape: [batch_size, 256, 14, 14].into(),
weight_shape: [256, 256, 3, 3].into(),
bias_shape: [256].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv12 = Conv2dBenchmark::<B> {
suffix: "input_512x7x7_weight_512x3x3_stride_1",
input_shape: [batch_size, 512, 7, 7].into(),
weight_shape: [512, 512, 3, 3].into(),
bias_shape: [512].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv13 = Conv2dBenchmark::<B> {
suffix: "input_96x224x224_weight_64x1x1_stride_1",
input_shape: [batch_size, 96, 224, 224].into(),
weight_shape: [64, 96, 1, 1].into(),
bias_shape: [64].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let benches = vec![
benchmark, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8, conv9, conv10, conv11,
conv12, conv13,
];
let mut results = Vec::new();

for bench in benches {
let result = black_box(run_benchmark(bench));
results.push(result);
}

save::<B>(results, device, feature_name, url, token).unwrap();
}

fn main() {
Expand Down
42 changes: 16 additions & 26 deletions backend-comparison/benches/matmul.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn::tensor::{backend::Backend, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;

Expand All @@ -21,17 +21,13 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
vec![self.shape_lhs.dims.clone(), self.shape_rhs.dims.clone()]
}

fn num_samples(&self) -> usize {
10
}

fn execute(&self, (lhs, rhs): Self::Args) {
lhs.clone().matmul(rhs.clone());
lhs.matmul(rhs);
}

fn prepare(&self) -> Self::Args {
let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default, &self.device);
let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default, &self.device);
let lhs = Tensor::zeros(self.shape_lhs.clone(), &self.device);
let rhs = Tensor::zeros(self.shape_rhs.clone(), &self.device);

(lhs, rhs)
}
Expand All @@ -48,24 +44,18 @@ fn bench<B: Backend>(
url: Option<&str>,
token: Option<&str>,
) {
const D: usize = 3;
let batch_size = 8;
let m = 2048;
let k = 2048;
let n = 2048;
let shape_lhs = [batch_size, m, k].into();
let shape_rhs = [batch_size, k, n].into();

let benchmark = MatmulBenchmark::<B, D>::new(shape_lhs, shape_rhs, device.clone());

save::<B>(
vec![run_benchmark(benchmark)],
device,
feature_name,
url,
token,
)
.unwrap();
let benchmarks = [(2, 4096, 4096, 4096), (8, 2048, 2048, 2048)]
.into_iter()
.map(|(b, m, n, k)| {
let shape_lhs = [b, m, k].into();
let shape_rhs = [b, k, n].into();

MatmulBenchmark::<B, 3>::new(shape_lhs, shape_rhs, device.clone())
})
.map(run_benchmark)
.collect();

save::<B>(benchmarks, device, feature_name, url, token).unwrap();
}

fn main() {
Expand Down
Loading

0 comments on commit a5624c1

Please sign in to comment.