Skip to content

Commit

Permalink
Feat/fusion/cmp (tracel-ai#992)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Nov 23, 2023
1 parent b86bc58 commit 58273a8
Show file tree
Hide file tree
Showing 6 changed files with 466 additions and 119 deletions.
15 changes: 15 additions & 0 deletions burn-wgpu/src/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ where
fn type_name() -> &'static str;
fn as_bytes(slice: &[Self]) -> &[u8];
fn from_bytes(bytes: &[u8]) -> &[Self];
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem;
}

/// The float element type for the wgpu backend.
Expand All @@ -27,6 +29,10 @@ impl WgpuElement for u32 {
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::U32
}
}

impl WgpuElement for i32 {
Expand All @@ -39,6 +45,10 @@ impl WgpuElement for i32 {
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::I32
}
}

impl WgpuElement for f32 {
Expand All @@ -51,6 +61,11 @@ impl WgpuElement for f32 {
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}

#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::F32
}
}

impl FloatElement for f32 {}
Expand Down
84 changes: 76 additions & 8 deletions burn-wgpu/src/fusion/codegen/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,37 @@ pub enum Operator {
input: Variable,
out: Variable,
},
Equal {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Lower {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Greater {
lhs: Variable,
rhs: Variable,
out: Variable,
},
LowerEqual {
lhs: Variable,
rhs: Variable,
out: Variable,
},
GreaterEqual {
lhs: Variable,
rhs: Variable,
out: Variable,
},
ConditionalAssign {
cond: Variable,
lhs: Variable,
rhs: Variable,
out: Variable,
},
AssignGlobal {
input: Variable,
out: Variable,
Expand Down Expand Up @@ -109,22 +140,41 @@ impl Display for Operator {
Operator::Recip { input, out } => {
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
}
Operator::Equal { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} == {rhs};"))
}
Operator::Lower { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} < {rhs};"))
}
Operator::Greater { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} > {rhs};"))
}
Operator::LowerEqual { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} <= {rhs};"))
}
Operator::GreaterEqual { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} >= {rhs};"))
}
Operator::AssignGlobal { input, out } => {
f.write_fmt(format_args!("{out}_global[id] = {input};"))
let elem = out.elem();
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
}
Operator::ReadGlobal {
variable,
position,
position_out,
} => {
let (global, local) = match variable {
Variable::Input(number) => {
(format!("input_{number}_global"), format!("input_{number}"))
}
Variable::Local(_) => panic!("can't read globala local variable."),
Variable::Output(number) => (
let (global, local, elem) = match variable {
Variable::Input(number, elem) => (
format!("input_{number}_global"),
format!("input_{number}"),
elem,
),
Variable::Local(_, _) => panic!("can't read global local variable."),
Variable::Output(number, elem) => (
format!("output_{number}_global"),
format!("output_{number}"),
elem,
),
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
};
Expand All @@ -144,7 +194,25 @@ for (var i: u32 = 1u; i <= rank; i++) {{
index_{local} += id / stride_out % shape * stride;
}}
let {local} = {global}[index_{local}];
let {local} = {elem}({global}[index_{local}]);
"
))
}
Operator::ConditionalAssign {
cond,
lhs,
rhs,
out,
} => {
let elem = out.elem();
f.write_fmt(format_args!(
"
var {out}: {elem};
if {cond} {{
{out} = {lhs};
}} else {{
{out} = {rhs};
}}
"
))
}
Expand Down
4 changes: 3 additions & 1 deletion burn-wgpu/src/fusion/codegen/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ pub enum Visibility {
ReadWrite,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
pub enum Elem {
F32,
#[allow(dead_code)]
I32,
U32,
Bool,
}

#[derive(Hash, PartialEq, Eq)]
Expand Down Expand Up @@ -187,6 +188,7 @@ impl Display for Elem {
Elem::F32 => f.write_str("f32"),
Elem::I32 => f.write_str("i32"),
Elem::U32 => f.write_str("u32"),
Elem::Bool => f.write_str("bool"),
}
}
}
Expand Down
23 changes: 17 additions & 6 deletions burn-wgpu/src/fusion/codegen/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,29 @@ use std::fmt::Display;

#[derive(Debug, Hash, Clone)]
pub enum Variable {
Input(u16),
Input(u16, Elem),
Scalar(u16, Elem),
Local(u16),
Output(u16),
Local(u16, Elem),
Output(u16, Elem),
}

impl Variable {
pub fn elem(&self) -> &Elem {
match self {
Variable::Input(_, e) => e,
Variable::Scalar(_, e) => e,
Variable::Local(_, e) => e,
Variable::Output(_, e) => e,
}
}
}

impl Display for Variable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Input(number) => f.write_fmt(format_args!("input_{number}")),
Variable::Local(number) => f.write_fmt(format_args!("local_{number}")),
Variable::Output(number) => f.write_fmt(format_args!("output_{number}")),
Variable::Input(number, _) => f.write_fmt(format_args!("input_{number}")),
Variable::Local(number, _) => f.write_fmt(format_args!("local_{number}")),
Variable::Output(number, _) => f.write_fmt(format_args!("output_{number}")),
Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")),
}
}
Expand Down
Loading

0 comments on commit 58273a8

Please sign in to comment.