Skip to content

Commit

Permalink
Refactor/jit fusion (tracel-ai#1750)
Browse files Browse the repository at this point in the history
* Reads & Writes with index_ref

* WIP

* Fix operations

* Cleanup
  • Loading branch information
nathanielsimard authored May 13, 2024
1 parent 022311d commit 9dcec0b
Show file tree
Hide file tree
Showing 19 changed files with 235 additions and 164 deletions.
48 changes: 40 additions & 8 deletions crates/burn-jit/src/codegen/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,16 @@ impl OutputInfo {
#[allow(dead_code)]
pub fn item(&self) -> Item {
match self {
OutputInfo::ArrayWrite { item, local: _ } => *item,
OutputInfo::ArrayWrite {
item,
local: _,
position: _,
} => *item,
OutputInfo::InputArrayWrite {
item,
input: _,
local: _,
position: _,
} => *item,
OutputInfo::Array { item } => *item,
}
Expand All @@ -298,9 +303,18 @@ pub enum OutputInfo {
/// Write the local variable to a new array.
///
/// This will create a new binding in the [compute shader](ComputeShader).
ArrayWrite { item: Item, local: u16 },
ArrayWrite {
item: Item,
local: u16,
position: Variable,
},
/// Write the local variable to an existing input binding.
InputArrayWrite { item: Item, input: u16, local: u16 },
InputArrayWrite {
item: Item,
input: u16,
local: u16,
position: Variable,
},
/// Simply register the output, but don't automatically add a write to it.
///
/// Useful when a [procedure](gpu::Procedure) writes to the output using
Expand All @@ -312,11 +326,16 @@ impl OutputInfo {
#[allow(dead_code)]
pub fn elem_size<R: Runtime>(&self) -> usize {
let elem = match self {
OutputInfo::ArrayWrite { item, local: _ } => bool_elem(item.elem()),
OutputInfo::ArrayWrite {
item,
local: _,
position: _,
} => bool_elem(item.elem()),
OutputInfo::InputArrayWrite {
item,
input: _,
local: _,
position: _,
} => bool_elem(item.elem()),
OutputInfo::Array { item } => bool_elem(item.elem()),
};
Expand Down Expand Up @@ -424,7 +443,11 @@ impl Compilation {

for array in self.info.outputs.drain(..) {
match array {
OutputInfo::ArrayWrite { item, local } => {
OutputInfo::ArrayWrite {
item,
local,
position,
} => {
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
Expand All @@ -441,10 +464,16 @@ impl Compilation {
self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalOutputArray(index, elem_adapted),
position,
);
index += 1;
}
OutputInfo::InputArrayWrite { item, input, local } => {
OutputInfo::InputArrayWrite {
item,
input,
local,
position,
} => {
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
Expand All @@ -454,6 +483,7 @@ impl Compilation {
self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
position,
);
}
OutputInfo::Array { item } => {
Expand Down Expand Up @@ -483,12 +513,13 @@ impl Compilation {
None => panic!("No output found."),
};

let (item, local) = match output {
OutputInfo::ArrayWrite { item, local } => (item, local),
let (item, local, position) = match output {
OutputInfo::ArrayWrite { item, local, position } => (item, local, position),
OutputInfo::InputArrayWrite {
item: _,
input,
local: _,
position: _,
} => {
assert_eq!(
*input, mapping.pos_input as u16,
Expand Down Expand Up @@ -521,6 +552,7 @@ impl Compilation {
item,
input: mapping.pos_input as u16,
local: *local,
position: *position,
};
}
}
Expand Down
23 changes: 17 additions & 6 deletions crates/burn-jit/src/codegen/dialect/gpu/procedure/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ pub struct ReadGlobal {
pub global: Variable,
/// The output variable to write the result.
pub out: Variable,
/// The reference position index.
pub position: Variable,
}

/// Read a global array with the given layout.
Expand All @@ -20,21 +22,24 @@ pub struct ReadGlobalWithLayout {
pub outs: Vec<Variable>,
/// The layout to be used.
pub layout: Variable,
/// The reference position index.
pub position: Variable,
}

impl ReadGlobal {
#[allow(missing_docs)]
pub fn expand(self, scope: &mut Scope) {
scope.register(Operator::Index(BinaryOperator {
lhs: self.global,
rhs: Variable::Id,
rhs: self.position,
out: self.out,
}));
}
/// Builds a new `ReadGlobal` whose `global` and `out` variables are vectorized
/// with the given `vectorization`.
///
/// The `position` variable is carried over unchanged.
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
Self {
global: self.global.vectorize(vectorization),
out: self.out.vectorize(vectorization),
// The position index is copied as-is; only the data variables are vectorized here.
position: self.position,
}
}
}
Expand All @@ -47,6 +52,10 @@ impl ReadGlobalWithLayout {
return None;
}

if self.position != other.position {
return None;
}

let mut globals = Vec::with_capacity(self.globals.len() + other.globals.len());
globals.extend(&self.globals);
globals.extend(&other.globals);
Expand All @@ -59,6 +68,7 @@ impl ReadGlobalWithLayout {
globals,
outs,
layout: self.layout,
position: self.position,
})
}

Expand All @@ -75,7 +85,7 @@ impl ReadGlobalWithLayout {
tensors: tensors.clone(),
layout: self.layout,
indexes: indexes.clone(),
index_ref: Variable::Id,
position: self.position,
dim_start: 0u32.into(),
dim_end: Variable::Rank,
}
Expand Down Expand Up @@ -103,6 +113,7 @@ impl ReadGlobalWithLayout {
.iter()
.map(|o| o.vectorize(vectorization))
.collect(),
position: self.position,
}
}
}
Expand All @@ -117,10 +128,10 @@ pub struct IndexOffsetGlobalWithLayout {
pub indexes: Vec<Variable>,
/// Reference layout.
pub layout: Variable,
/// Index that corresponds to the reference layout.
/// Position index that corresponds to the reference layout.
///
/// All other indexes will be made to be compatible with this one.
pub index_ref: Variable,
pub position: Variable,
pub dim_start: Variable,
pub dim_end: Variable,
}
Expand All @@ -130,7 +141,7 @@ impl IndexOffsetGlobalWithLayout {
pub fn expand(self, scope: &mut Scope) {
let layout = self.layout;
let index_item_ty = Item::Scalar(Elem::UInt);
let offset_ref = self.index_ref;
let offset_ref = self.position;
let zero: Variable = 0u32.into();
let vectorization_factor: Variable = match self.tensors[0].item() {
Item::Vec4(_) => 4u32,
Expand Down Expand Up @@ -187,7 +198,7 @@ impl IndexOffsetGlobalWithLayout {
.map(|t| t.vectorize(vectorization))
.collect(),
layout: self.layout.vectorize(vectorization),
index_ref: self.index_ref.vectorize(vectorization),
position: self.position.vectorize(vectorization),
dim_start: self.dim_start,
dim_end: self.dim_end,
}
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-jit/src/codegen/dialect/gpu/procedure/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ use serde::{Deserialize, Serialize};
pub struct WriteGlobal {
/// The variable whose value is written into `global`.
pub input: Variable,
/// The global array that receives the value.
pub global: Variable,
/// The index into `global` at which `input` is written.
pub position: Variable,
}

impl WriteGlobal {
#[allow(missing_docs)]
pub fn expand(self, scope: &mut Scope) {
let output = self.global;
let input = self.input;
let position = Variable::Id;
let position = self.position;

gpu!(scope, output[position] = input);
}
Expand All @@ -23,6 +24,7 @@ impl WriteGlobal {
Self {
input: self.input.vectorize(vectorization),
global: self.global.vectorize(vectorization),
position: self.position,
}
}
}
53 changes: 33 additions & 20 deletions crates/burn-jit/src/codegen/dialect/gpu/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ pub struct Scope {
locals: Vec<Variable>,
shared_memories: Vec<Variable>,
local_arrays: Vec<Variable>,
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
index_offset_with_output_layout_position: Vec<usize>,
writes_global: Vec<(Variable, Variable)>,
writes_global: Vec<(Variable, Variable, Variable)>,
reads_scalar: Vec<(Variable, Variable)>,
pub layout_ref: Option<Variable>,
undeclared: u16,
Expand Down Expand Up @@ -102,8 +102,13 @@ impl Scope {
/// Reads an input array to a local variable.
///
/// The index refers to the argument position of the array in the compute shader.
pub(crate) fn read_array<I: Into<Item>>(&mut self, index: u16, item: I) -> Variable {
self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout)
pub(crate) fn read_array<I: Into<Item>>(
&mut self,
index: u16,
item: I,
position: Variable,
) -> Variable {
self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout, position)
}

/// Add the procedure into the scope.
Expand Down Expand Up @@ -143,27 +148,31 @@ impl Scope {
self.locals
.iter_mut()
.for_each(|var| *var = var.vectorize(vectorization));
self.reads_global.iter_mut().for_each(|(input, _, output)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
self.writes_global.iter_mut().for_each(|(input, output)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
self.reads_global
.iter_mut()
.for_each(|(input, _, output, _position)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
self.writes_global
.iter_mut()
.for_each(|(input, output, _)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
}

/// Writes a variable to the given output.
///
/// Notes:
///
/// This should only be used when doing compilation.
pub(crate) fn write_global(&mut self, input: Variable, output: Variable) {
pub(crate) fn write_global(&mut self, input: Variable, output: Variable, position: Variable) {
// This assumes that all outputs have the same layout
if self.layout_ref.is_none() {
self.layout_ref = Some(output);
}
self.writes_global.push((input, output));
self.writes_global.push((input, output, position));
}

/// Writes a variable to the given output.
Expand All @@ -184,10 +193,10 @@ impl Scope {
///
/// This should only be used when doing compilation.
pub(crate) fn update_read(&mut self, index: u16, strategy: ReadingStrategy) {
if let Some((_, strategy_old, _)) = self
if let Some((_, strategy_old, _, _position)) = self
.reads_global
.iter_mut()
.find(|(var, _, _)| var.index() == Some(index))
.find(|(var, _, _, _)| var.index() == Some(index))
{
*strategy_old = strategy;
}
Expand All @@ -197,7 +206,7 @@ impl Scope {
pub(crate) fn read_globals(&self) -> Vec<(u16, ReadingStrategy)> {
self.reads_global
.iter()
.map(|(var, strategy, _)| match var {
.map(|(var, strategy, _, _)| match var {
Variable::GlobalInputArray(id, _) => (*id, *strategy),
_ => panic!("Can only read global input arrays."),
})
Expand Down Expand Up @@ -250,7 +259,7 @@ impl Scope {

let mut operations = Vec::new();

for (input, strategy, local) in self.reads_global.drain(..) {
for (input, strategy, local, position) in self.reads_global.drain(..) {
match strategy {
ReadingStrategy::OutputLayout => {
let output = self.layout_ref.expect(
Expand All @@ -261,13 +270,15 @@ impl Scope {
globals: vec![input],
layout: output,
outs: vec![local],
position,
},
)));
}
ReadingStrategy::Plain => {
operations.push(Operation::Procedure(Procedure::ReadGlobal(ReadGlobal {
global: input,
out: local,
position,
})))
}
}
Expand All @@ -288,10 +299,11 @@ impl Scope {
operations.push(op);
}

for (input, global) in self.writes_global.drain(..) {
for (input, global, position) in self.writes_global.drain(..) {
operations.push(Operation::Procedure(Procedure::WriteGlobal(WriteGlobal {
input,
global,
position,
})))
}

Expand Down Expand Up @@ -323,6 +335,7 @@ impl Scope {
index: u16,
item: Item,
strategy: ReadingStrategy,
position: Variable,
) -> Variable {
let item_global = match item.elem() {
Elem::Bool => match item {
Expand All @@ -336,7 +349,7 @@ impl Scope {
let input = Variable::GlobalInputArray(index, item_global);
let index = self.new_local_index();
let local = Variable::Local(index, item, self.depth);
self.reads_global.push((input, strategy, local));
self.reads_global.push((input, strategy, local, position));
self.locals.push(local);
local
}
Expand Down
Loading

0 comments on commit 9dcec0b

Please sign in to comment.