Skip to content

Commit

Permalink
Introduce autotuning to conv2d and conv_transpose2d with a new `i…
Browse files Browse the repository at this point in the history
…m2col`/`GEMM` algorithm (tracel-ai#2287)
  • Loading branch information
wingertge authored Sep 23, 2024
1 parent 2c8514c commit 97af8c6
Show file tree
Hide file tree
Showing 35 changed files with 1,806 additions and 95 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 19 additions & 18 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ version.workspace = true

[features]
# we depend on wgpu and autotune by default because we use the burn-wgpu crate to get system information
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
candle-accelerate = ["burn/candle", "burn/accelerate"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
candle-accelerate = ["burn/candle", "burn/accelerate"]
cuda-jit = ["burn/cuda-jit"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
Expand All @@ -24,7 +25,6 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
cuda-jit = ["burn/cuda-jit"]

[dependencies]
arboard = { workspace = true }
Expand All @@ -33,11 +33,13 @@ burn-common = { path = "../crates/burn-common", version = "0.15.0" }
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0", optional = true }
clap = { workspace = true }
colored = { workspace = true }
cubecl = { workspace = true, features = ["wgpu"] }
derive-new = { workspace = true }
dirs = { workspace = true }
github-device-flow = { workspace = true }
os_info = { workspace = true }
half = { workspace = true }
indicatif = { workspace = true }
os_info = { workspace = true }
percent-encoding = { workspace = true }
rand = { workspace = true }
reqwest = { workspace = true, features = ["blocking", "json"] }
Expand All @@ -48,69 +50,68 @@ strum_macros = { workspace = true }
sysinfo = { workspace = true, features = ["serde"] }
wgpu = { workspace = true }
wsl = { workspace = true }
cubecl = { workspace = true, features = ["wgpu"] }

[dev-dependencies]
rstest = { workspace = true }
serial_test = { workspace = true }

[[bench]]
name = "unary"
harness = false
name = "unary"

[[bench]]
name = "binary"
harness = false
name = "binary"

[[bench]]
harness = false
name = "max-pool2d"
path = "benches/max_pool2d.rs"
harness = false

[[bench]]
harness = false
name = "conv-transpose2d"
path = "benches/conv_transpose2d.rs"
harness = false

[[bench]]
harness = false
name = "conv-transpose3d"
path = "benches/conv_transpose3d.rs"
harness = false

[[bench]]
name = "conv2d"
harness = false
name = "conv2d"

[[bench]]
name = "conv3d"
harness = false
name = "conv3d"

[[bench]]
name = "matmul"
harness = false
name = "matmul"

[[bench]]
name = "data"
harness = false
name = "data"

[[bench]]
name = "load-record"
harness = false
name = "load-record"
path = "benches/load_record.rs"

[[bench]]
harness = false
name = "custom-gelu"
path = "benches/custom_gelu.rs"
harness = false

[[bench]]
harness = false
name = "resnet50"
path = "benches/resnet.rs"
harness = false

[[bench]]
name = "autodiff"
harness = false
name = "autodiff"

[[bin]]
name = "burnbench"
Expand Down
17 changes: 10 additions & 7 deletions backend-comparison/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fs;
use std::path::Path;
use std::process::Command;

const MODELS_DIR: &str = "/tmp/models";
const MODELS_REPO: &str = "https://github.com/tracel-ai/models.git";

// Patch resnet code (remove pretrained feature code)
Expand Down Expand Up @@ -224,16 +223,18 @@ where
}

fn main() {
let models_dir = std::env::temp_dir().join("models");
let models_dir = models_dir.as_path();
// Checkout ResNet code from models repo
let models_dir = Path::new(MODELS_DIR);
let models_dir = Path::new(models_dir);
if !models_dir.join(".git").exists() {
run("git", |command| {
command
.arg("clone")
.arg("--depth=1")
.arg("--no-checkout")
.arg(MODELS_REPO)
.arg(MODELS_DIR)
.arg(models_dir)
});

run("git", |command| {
Expand Down Expand Up @@ -266,10 +267,12 @@ fn main() {
let source_path = models_dir.join("resnet-burn").join("resnet").join("src");
let dest_path = Path::new(&out_dir);

for file in fs::read_dir(source_path).unwrap() {
let source_file = file.unwrap().path();
let dest_file = dest_path.join(source_file.file_name().unwrap());
fs::copy(source_file, dest_file).expect("should copy file successfully");
if let Ok(source_path) = fs::read_dir(source_path) {
for file in source_path {
let source_file = file.unwrap().path();
let dest_file = dest_path.join(source_file.file_name().unwrap());
fs::copy(source_file, dest_file).expect("should copy file successfully");
}
}

// Delete cloned repository contents
Expand Down
2 changes: 1 addition & 1 deletion backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ macro_rules! bench_on_backend {
{
use burn::backend::cuda_jit::{Cuda, CudaDevice};

bench::<Cuda>(&CudaDevice::default(), feature_name, url, token);
bench::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}
};
}
Expand Down
20 changes: 11 additions & 9 deletions crates/burn-jit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-jit"
version.workspace = true

[features]
autotune = []
default = ["autotune", "std", "fusion", "cubecl/default"]
std = ["cubecl/std"]
doc = ["default"]
autotune = []
template = []
fusion = ["burn-fusion"]
export_tests = [
"burn-tensor-testgen",
"serial_test",
Expand All @@ -25,32 +22,37 @@ export_tests = [
"burn-ndarray",
"fusion",
]
fusion = ["burn-fusion"]
std = ["cubecl/std"]
template = []

[dependencies]
cubecl = { workspace = true, features = ["linalg"] }
burn-common = { path = "../burn-common", version = "0.15.0" }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl"] }
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [
"cubecl",
] }
cubecl = { workspace = true, features = ["linalg"] }

bytemuck = { workspace = true }
derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
log = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
spin = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }

# Template
serde = { workspace = true }
text_placeholder = { workspace = true, features = ["struct_context"] }

hashbrown = { workspace = true }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.15.0", optional = true }
hashbrown = { workspace = true }

# When exporting tests
serial_test = { workspace = true, optional = true }
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", default-features = false, optional = true }
burn-ndarray = { path = "../burn-ndarray", version = "0.15.0", optional = true }
serial_test = { workspace = true, optional = true }

[package.metadata.docs.rs]
features = ["doc"]
6 changes: 3 additions & 3 deletions crates/burn-jit/src/fusion/tracing/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,14 +511,14 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operator::AtomicCompareAndSwap(_op) => {
// Nothing to do.
}
Operator::Magnitude(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operator::AtomicCompareAndSwap(_op) => {
// Nothing to do.
}
},
Operation::Procedure(proc) => {
match proc {
Expand Down
125 changes: 125 additions & 0 deletions crates/burn-jit/src/kernel/conv/conv2d/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use burn_tensor::{
ops::{ConvOptions, ConvTransposeOptions},
TensorData,
};

use crate::{tensor::JitTensor, FloatElement, IntElement, JitElement, JitRuntime};

#[cfg(feature = "autotune")]
use super::conv2d_autotune;
use super::{
conv2d_direct, conv2d_im2col, conv_transpose2d_autotune, conv_transpose2d_col2im,
conv_transpose2d_direct, implicit_gemm::conv2d_implicit_gemm,
};

/// The strategy to be used when launching a convolution kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Conv2dStrategy {
    /// A simple direct convolution.
    Direct,
    #[cfg(feature = "autotune")]
    /// Using autotune to choose the best kernel based on runtime information.
    Autotune,
    /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage.
    Gemm,
    /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and
    /// has constraints on tensor shape.
    ImplicitGemm,
}

impl Default for Conv2dStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return Conv2dStrategy::Autotune;

// if autotune is disabled, default to the more memory-conservative algorithm
#[cfg(not(feature = "autotune"))]
Conv2dStrategy::Direct
}
}

/// The strategy to be used when launching a conv_transpose kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// A simple direct convolution.
    Direct,
    #[cfg(feature = "autotune")]
    /// Using autotune to choose the best kernel based on runtime information.
    Autotune,
    /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage.
    Gemm,
}

impl Default for ConvTranspose2dStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return ConvTranspose2dStrategy::Autotune;

// if autotune is disabled, default to the more memory-conservative algorithm
#[cfg(not(feature = "autotune"))]
ConvTranspose2dStrategy::Direct
}
}

/// Perform a 2D convolution, dispatching to the kernel selected by `strategy`.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
///
pub fn conv2d<R: JitRuntime, E: FloatElement, I: IntElement>(
    input: JitTensor<R, E, 4>,
    weight: JitTensor<R, E, 4>,
    bias: Option<JitTensor<R, E, 1>>,
    options: ConvOptions<2>,
    strategy: Conv2dStrategy,
) -> JitTensor<R, E, 4> {
    // Each arm hands the tensors off to the concrete kernel implementation;
    // the `Autotune` arm only exists when the feature is compiled in.
    match strategy {
        #[cfg(feature = "autotune")]
        Conv2dStrategy::Autotune => conv2d_autotune::<R, E, I>(input, weight, bias, options),
        Conv2dStrategy::Direct => conv2d_direct::<R, E, I>(input, weight, bias, options),
        Conv2dStrategy::Gemm => conv2d_im2col::<R, E, I>(input, weight, bias, options),
        Conv2dStrategy::ImplicitGemm => {
            conv2d_implicit_gemm::<R, E, I>(input, weight, bias, options)
        }
    }
}

/// Perform a 2D transposed convolution with the given strategy
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the transposed convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
///
pub fn conv_transpose2d<R: JitRuntime, E: FloatElement, I: IntElement>(
    input: JitTensor<R, E, 4>,
    weight: JitTensor<R, E, 4>,
    bias: Option<JitTensor<R, E, 1>>,
    options: ConvTransposeOptions<2>,
    strategy: ConvTranspose2dStrategy,
) -> JitTensor<R, E, 4> {
    // Dispatch to the concrete kernel implementation selected by `strategy`.
    match strategy {
        ConvTranspose2dStrategy::Direct => {
            conv_transpose2d_direct::<R, E, I>(input, weight, bias, options)
        }
        #[cfg(feature = "autotune")]
        ConvTranspose2dStrategy::Autotune => {
            conv_transpose2d_autotune::<R, E, I>(input, weight, bias, options)
        }
        ConvTranspose2dStrategy::Gemm => {
            conv_transpose2d_col2im::<R, E, I>(input, weight, bias, options)
        }
    }
}

/// Read a JIT tensor back from its backing device buffer into host-side
/// [`TensorData`], for debugging purposes only (blocks on the device read).
#[allow(unused)]
pub(crate) fn debug_data<R: JitRuntime, E: JitElement, const D: usize>(
    tensor: JitTensor<R, E, D>,
) -> TensorData {
    // Synchronously fetch the raw bytes, then reinterpret them as elements.
    let binding = tensor.handle.binding();
    let bytes = tensor.client.read(binding);
    let values = E::from_bytes(&bytes).to_vec();
    TensorData::new(values, tensor.shape)
}
Loading

0 comments on commit 97af8c6

Please sign in to comment.