Skip to content

Commit

Permalink
Refactor/wgpu/memco (tracel-ai#1340)
Browse files Browse the repository at this point in the history
nathanielsimard authored Feb 22, 2024
1 parent bff4961 commit b256c04
Showing 35 changed files with 1,657 additions and 796 deletions.
32 changes: 0 additions & 32 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions crates/burn-wgpu/benches/matmul.rs
Original file line number Diff line number Diff line change
@@ -113,9 +113,9 @@ pub fn bench(device: &WgpuDevice) {
const D: usize = 3;
let num_repeats = 3;
let batch_size = 3;
let m = 1007;
let k = 1023;
let n = 1005;
let m = 1024;
let k = 1024;
let n = 1024;
let shape_lhs = Shape::new([batch_size, m, k]);
let shape_rhs = Shape::new([batch_size, k, n]);

37 changes: 27 additions & 10 deletions crates/burn-wgpu/src/codegen/compilation.rs
Original file line number Diff line number Diff line change
@@ -77,9 +77,14 @@ pub enum OutputInfo {
/// Write the local variable to a new array.
///
/// This will create a new binding in the [compute shader](ComputeShader).
Array { item: Item, local: u16 },
ArrayWrite { item: Item, local: u16 },
/// Write the local variable to an existing input binding.
Input { item: Item, input: u16, local: u16 },
InputArrayWrite { item: Item, input: u16, local: u16 },
/// Simply register the output, but don't automatically add a write to it.
///
/// Useful when a [procedure](gpu::Procedure) writes to the output using
/// [operations](gpu::Operation).
Array { item: Item },
}

impl Compilation {
@@ -125,8 +130,6 @@ impl Compilation {
named,
workgroup_size: settings.workgroup_size,
body: self.info.scope,
num_workgroups: true,
global_invocation_id: true,
}
}

@@ -174,7 +177,7 @@ impl Compilation {

for array in self.info.outputs.drain(..) {
match array {
OutputInfo::Array { item, local } => {
OutputInfo::ArrayWrite { item, local } => {
let item = item.vectorize(settings.vectorization);
let elem_adapted = bool_item(item);

@@ -190,14 +193,27 @@ impl Compilation {
);
index += 1;
}
OutputInfo::Input { item, input, local } => {
OutputInfo::InputArrayWrite { item, input, local } => {
let item = item.vectorize(settings.vectorization);

self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
);
}
OutputInfo::Array { item } => {
let item = item.vectorize(settings.vectorization);
let elem_adapted = bool_item(item);

self.output_bindings.push(Binding {
item: elem_adapted,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
});

index += 1;
}
}
}
}
@@ -209,12 +225,13 @@ impl Compilation {
};

let (item, local) = match output {
OutputInfo::Array { item, local } => (item, local),
OutputInfo::Input {
OutputInfo::ArrayWrite { item, local } => (item, local),
OutputInfo::InputArrayWrite {
item: _,
input: _,
local: _,
} => return, // Output already updated.
} => return,
OutputInfo::Array { item: _ } => return,
};

let item = match self.input_bindings.get_mut(mapping.pos_input) {
@@ -235,7 +252,7 @@ impl Compilation {
};

// Update the output.
*output = OutputInfo::Input {
*output = OutputInfo::InputArrayWrite {
item,
input: mapping.pos_input as u16,
local: *local,
98 changes: 98 additions & 0 deletions crates/burn-wgpu/src/codegen/dialect/gpu/branch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use super::{Elem, Item, Scope, Variable};
use serde::{Deserialize, Serialize};

/// All branching types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Branch {
// An if statement.
If(If),
// An if else statement.
IfElse(IfElse),
// A range loop.
RangeLoop(RangeLoop),
// A return statement.
Return,
// A break statement.
Break,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct If {
pub cond: Variable,
pub scope: Scope,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IfElse {
pub cond: Variable,
pub scope_if: Scope,
pub scope_else: Scope,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeLoop {
pub i: Variable,
pub start: Variable,
pub end: Variable,
pub scope: Scope,
}

impl If {
/// Registers an if statement to the given scope.
pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, cond: Variable, func: F) {
let mut scope = parent_scope.child();

func(&mut scope);

let op = Self { cond, scope };
parent_scope.register(Branch::If(op));
}
}

impl IfElse {
/// Registers an if else statement to the given scope.
pub fn register<IF, ELSE>(
parent_scope: &mut Scope,
cond: Variable,
func_if: IF,
func_else: ELSE,
) where
IF: Fn(&mut Scope),
ELSE: Fn(&mut Scope),
{
let mut scope_if = parent_scope.child();
let mut scope_else = parent_scope.child();

func_if(&mut scope_if);
func_else(&mut scope_else);

parent_scope.register(Branch::IfElse(Self {
cond,
scope_if,
scope_else,
}));
}
}

impl RangeLoop {
/// Registers a range loop to the given scope.
pub fn register<F: Fn(Variable, &mut Scope)>(
parent_scope: &mut Scope,
start: Variable,
end: Variable,
func: F,
) {
let mut scope = parent_scope.child();
let index_ty = Item::Scalar(Elem::UInt);
let i = scope.create_local_undeclared(index_ty);

func(i, &mut scope);

parent_scope.register(Branch::RangeLoop(Self {
i,
start,
end,
scope,
}));
}
}
45 changes: 31 additions & 14 deletions crates/burn-wgpu/src/codegen/dialect/gpu/macros.rs
Original file line number Diff line number Diff line change
@@ -5,6 +5,10 @@ macro_rules! gpu {
($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => {
gpu!($scope, $out = add($lhs, $rhs))
};
// out += input
($scope:expr, $out:ident += $input:ident) => {
gpu!($scope, $out = add($out, $input))
};
// out = add(lhs, rhs)
($scope:expr, $out:ident = add($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Add(
@@ -121,6 +125,12 @@ macro_rules! gpu {
gpu!(binary $lhs, $rhs, $out)
));
};
// out[lhs] = rhs
($scope:expr, $out:ident[$lhs:ident] = $rhs:expr) => {
$scope.register($crate::codegen::dialect::gpu::Operator::IndexAssign(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = |input|
($scope:expr, $out:ident = |$input:ident|) => {
gpu!($scope, $out = abs($input))
@@ -180,38 +190,45 @@ macro_rules! gpu {
));
};
// out = input
($scope:expr, eval $arg:expr) => {
gpu!($scope, $arg);
};
// out = input
($scope:expr, $out:ident = $input:ident) => {
$scope.register($crate::codegen::dialect::gpu::Operator::AssignLocal(
$scope.register($crate::codegen::dialect::gpu::Operator::Assign(
gpu!(unary $input, $out)
));
};
// out = input
($scope:expr, $out:ident = $input:ident) => {
$scope.register($crate::codegen::dialect::gpu::Operator::AssignLocal(
$scope.register($crate::codegen::dialect::gpu::Operator::Assign(
gpu!(unary $input, $out)
));
};
// out = shape(tensor, dim)
($scope:expr, $out:ident = shape($input:expr, $dim:expr)) => {
$scope.register(Metadata::Shape {
dim: $dim,
var: $input,
out: $out,
$scope.register($crate::codegen::dialect::gpu::Metadata::Shape {
dim: $dim.into(),
var: $input.into(),
out: $out.into(),
});
};
// out = stride(tensor, dim)
($scope:expr, $out:ident = stride($input:expr, $dim:expr)) => {
$scope.register(Metadata::Stride {
dim: $dim,
var: $input,
out: $out,
$scope.register($crate::codegen::dialect::gpu::Metadata::Stride {
dim: $dim.into(),
var: $input.into(),
out: $out.into(),
});
};
// range(start, end).for_each(|scope| { ... })
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
};
// if (cond).then(|scope| { ... })
($scope:expr, if ($cond:expr).then($arg:expr)) => {
$crate::codegen::dialect::gpu::If::register($scope, $cond.into(), $arg);
};
// if (cond).then(|scope| { ... }).else(|scope| { ... })
($scope:expr, if ($cond:expr).then($arg_if:expr).else($arg_else:expr)) => {
$crate::codegen::dialect::gpu::IfElse::register($scope, $cond.into(), $arg_if, $arg_else);
};
(binary $lhs:expr, $rhs:expr, $out:expr) => {
$crate::codegen::dialect::gpu::BinaryOperator {
lhs: $lhs.into(),
7 changes: 5 additions & 2 deletions crates/burn-wgpu/src/codegen/dialect/gpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
pub(crate) mod algorithm;

mod branch;
mod macros;
mod operation;
mod procedure;
mod processing;
mod scope;
mod shader;
mod variable;
mod vectorization;

pub(crate) use branch::*;
pub(crate) use macros::gpu;
pub(crate) use operation::*;
pub(crate) use procedure::*;
pub(crate) use scope::*;
pub(crate) use shader::*;
pub(crate) use variable::*;
Loading
Oops, something went wrong.

0 comments on commit b256c04

Please sign in to comment.