Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CubeCL first iteration #1756

Merged
merged 60 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
2142730
WIP
nathanielsimard Apr 9, 2024
94f476f
WIP
nathanielsimard Apr 10, 2024
8e4d39e
WIP
nathanielsimard Apr 10, 2024
85a0d83
Merge branch 'main' into feat/cube-pl
nathanielsimard Apr 11, 2024
f15919f
WIP
nathanielsimard Apr 12, 2024
3c64001
Push
nathanielsimard Apr 16, 2024
8c82ea8
Wip
nathanielsimard Apr 16, 2024
8ed5217
little refactor
louisfd Apr 17, 2024
dbbfdb7
wip
louisfd Apr 18, 2024
797d997
the right number of clones
louisfd Apr 19, 2024
2617e6b
add comments
louisfd Apr 19, 2024
16b454a
wip
louisfd Apr 22, 2024
c024e73
merge main
louisfd Apr 26, 2024
d756412
for loop tests
louisfd Apr 29, 2024
faea9d7
wip
louisfd Apr 29, 2024
709522e
refactor identify leaves and deletables
louisfd Apr 29, 2024
7f387f4
Merge branch 'main' of github.com:tracel-ai/burn
louisfd Apr 30, 2024
51b6085
Merge branch 'main' into feat/cube-pl
louisfd Apr 30, 2024
32bfcae
merge main
louisfd Apr 30, 2024
7a25525
wip
louisfd May 1, 2024
99fe8d8
prelude working
louisfd May 2, 2024
ad83d02
fix for loop
louisfd May 2, 2024
9547998
wip
louisfd May 2, 2024
bcd3c86
if
louisfd May 3, 2024
9cfafd9
if using variable inside
louisfd May 3, 2024
9310451
while loop
louisfd May 3, 2024
08f0f6e
loop and break
louisfd May 3, 2024
fb65985
assign add
louisfd May 3, 2024
8fd3541
wip
louisfd May 6, 2024
a14b653
variable reuse
louisfd May 6, 2024
fab9201
cast elem
louisfd May 6, 2024
b9bac61
wip cast
louisfd May 8, 2024
6aa7ccc
cast kind float
louisfd May 8, 2024
16e38ac
refactor elements
louisfd May 9, 2024
4f7c56f
make tests work
louisfd May 9, 2024
6375736
cast kind done
louisfd May 9, 2024
0c50975
rename gpu macro
louisfd May 9, 2024
da92893
refactor
louisfd May 10, 2024
2047e47
T::new
louisfd May 10, 2024
3292198
rename cast
louisfd May 10, 2024
9046baa
type system becoming great
louisfd May 11, 2024
8266bae
merge main + dirty fix for F64
louisfd May 11, 2024
9c347da
fmt
louisfd May 11, 2024
2cf7603
clippy
louisfd May 11, 2024
3b0052f
found the culprit
louisfd May 11, 2024
6a42148
typo
louisfd May 11, 2024
9a4a94c
add doc
louisfd May 13, 2024
80312ac
refactor codegen into files
louisfd May 13, 2024
a2b0369
minor refactor
louisfd May 13, 2024
d342d7e
prevent clippy from breaking tests
louisfd May 13, 2024
fa648e4
add doc
louisfd May 13, 2024
a2b63a3
more clean
louisfd May 13, 2024
f3341e9
fix expand outputs
louisfd May 13, 2024
cd183e9
traits, modules and parenthesis
louisfd May 13, 2024
652c68a
fmt
louisfd May 13, 2024
040bbe0
clippy
louisfd May 13, 2024
73ecc19
fmt again?
louisfd May 13, 2024
2e348f1
ops not compiling
louisfd May 15, 2024
c7a2454
make uint numeric
louisfd May 15, 2024
2bf7db9
rename new/constant to lit
louisfd May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
rename gpu macro
  • Loading branch information
louisfd committed May 9, 2024
commit 0c509759bad452cf33df465161b4c243dc35f817
2 changes: 0 additions & 2 deletions crates/burn-cube-macros/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ pub(crate) fn get_prelude(needed_functions: &HashSet<VariableKey>) -> proc_macro

fn codegen_float_new() -> proc_macro2::TokenStream {
quote::quote! {
// TODO ENCAPSULATE IMPORTS

pub fn float_new<F: burn_cube::Float>(val: f32) -> F {
F::new(val, 1)
}
Expand Down
50 changes: 26 additions & 24 deletions crates/burn-cube/tests/cast_elem.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use burn_cube::{cube, Bool, CubeContext, Float, Int, UInt, F32, I32};
use burn_jit::gpu;
use burn_jit::gpu::{Elem, Item, Variable};
use burn_jit::{
cube_inline,
gpu::{Elem, Item, Variable},
};

macro_rules! cast_test {
($name:ident, $module:ident, $from:expr, $to:expr) => {
Expand All @@ -15,7 +17,7 @@ macro_rules! cast_test {

assert_eq!(
format!("{:?}", scope.operations),
gpu_macro_ref_cast($from, $to)
inline_macro_ref_cast($from, $to)
);
}
};
Expand All @@ -32,7 +34,7 @@ macro_rules! cast_test {

assert_eq!(
format!("{:?}", scope.operations),
gpu_macro_ref_identity($ty)
inline_macro_ref_identity($ty)
);
}
};
Expand Down Expand Up @@ -246,7 +248,7 @@ cast_test!(
Item::Scalar(Elem::Bool)
);

fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String {
fn inline_macro_ref_cast(from_item: Item, to_item: Item) -> String {
let mut context = CubeContext::root();
let x = context.create_local(from_item);

Expand All @@ -257,25 +259,25 @@ fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String {
let z = scope.create_local(to_item);

match from_item.elem() {
Elem::Float(_) => gpu!(scope, y = x + 2f32),
Elem::Int(_) => gpu!(scope, y = x + 2i32),
Elem::UInt => gpu!(scope, y = x + 2u32),
Elem::Bool => gpu!(scope, y = x && false),
Elem::Float(_) => cube_inline!(scope, y = x + 2f32),
Elem::Int(_) => cube_inline!(scope, y = x + 2i32),
Elem::UInt => cube_inline!(scope, y = x + 2u32),
Elem::Bool => cube_inline!(scope, y = x && false),
}

gpu!(scope, y_casted = cast(y));
cube_inline!(scope, y_casted = cast(y));

match to_item.elem() {
Elem::Float(_) => gpu!(scope, z = y_casted + 34f32),
Elem::Int(_) => gpu!(scope, z = y_casted + 34i32),
Elem::UInt => gpu!(scope, z = y_casted + 34u32),
Elem::Bool => gpu!(scope, z = y_casted || true),
Elem::Float(_) => cube_inline!(scope, z = y_casted + 34f32),
Elem::Int(_) => cube_inline!(scope, z = y_casted + 34i32),
Elem::UInt => cube_inline!(scope, z = y_casted + 34u32),
Elem::Bool => cube_inline!(scope, z = y_casted || true),
}

format!("{:?}", scope.operations)
}

fn gpu_macro_ref_identity(item: Item) -> String {
fn inline_macro_ref_identity(item: Item) -> String {
// When staying with the same type variables are automatically reused in cube
let mut context = CubeContext::root();
let x = context.create_local(item);
Expand All @@ -285,19 +287,19 @@ fn gpu_macro_ref_identity(item: Item) -> String {
let y = scope.create_local(item);

match item.elem() {
Elem::Float(_) => gpu!(scope, y = x + 2f32),
Elem::Int(_) => gpu!(scope, y = x + 2i32),
Elem::UInt => gpu!(scope, y = x + 2u32),
Elem::Bool => gpu!(scope, y = x && false),
Elem::Float(_) => cube_inline!(scope, y = x + 2f32),
Elem::Int(_) => cube_inline!(scope, y = x + 2i32),
Elem::UInt => cube_inline!(scope, y = x + 2u32),
Elem::Bool => cube_inline!(scope, y = x && false),
}

gpu!(scope, x = cast(y));
cube_inline!(scope, x = cast(y));

match item.elem() {
Elem::Float(_) => gpu!(scope, y = x + 34f32),
Elem::Int(_) => gpu!(scope, y = x + 34i32),
Elem::UInt => gpu!(scope, y = x + 34u32),
Elem::Bool => gpu!(scope, y = x || true),
Elem::Float(_) => cube_inline!(scope, y = x + 34f32),
Elem::Int(_) => cube_inline!(scope, y = x + 34i32),
Elem::UInt => cube_inline!(scope, y = x + 34u32),
Elem::Bool => cube_inline!(scope, y = x || true),
}

format!("{:?}", scope.operations)
Expand Down
27 changes: 14 additions & 13 deletions crates/burn-cube/tests/cast_kind.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use burn_cube::{cube, CubeContext, Float, Int, F32, F64, I32, I64};
use burn_jit::gpu;
// use burn_jit::gpu::FloatKind;
use burn_jit::gpu::{Elem, Item};
use burn_jit::{
cube_inline,
gpu::{Elem, Item},
};

#[cube]
pub fn cast_float_kind<F1: Float, F2: Float>(input: F1) {
Expand All @@ -28,7 +29,7 @@ fn cube_cast_float_kind_test() {
cast_float_kind::expand::<F64, F32>(&mut context, input);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_float());
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float());
}

#[test]
Expand All @@ -41,10 +42,10 @@ fn cube_cast_int_kind_test() {
cast_int_kind::expand::<I32, I64>(&mut context, input);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_int());
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());
}

fn gpu_macro_ref_float() -> String {
fn inline_macro_ref_float() -> String {
let mut context = CubeContext::root();
let float_64 = Item::Scalar(Elem::Float(F64::into_kind()));
let float_32 = Item::Scalar(Elem::Float(F32::into_kind()));
Expand All @@ -55,14 +56,14 @@ fn gpu_macro_ref_float() -> String {
let y = scope.create_local(float_32);
let z = scope.create_local(float_32);

gpu!(scope, x = input + 5.9f32 as f64);
gpu!(scope, y = cast(x));
gpu!(scope, z = y + 2.3f32);
cube_inline!(scope, x = input + 5.9f32 as f64);
cube_inline!(scope, y = cast(x));
cube_inline!(scope, z = y + 2.3f32);

format!("{:?}", scope.operations)
}

fn gpu_macro_ref_int() -> String {
fn inline_macro_ref_int() -> String {
let mut context = CubeContext::root();
let int_32 = Item::Scalar(Elem::Int(I32::into_kind()));
let int_64 = Item::Scalar(Elem::Int(I64::into_kind()));
Expand All @@ -73,9 +74,9 @@ fn gpu_macro_ref_int() -> String {
let y = scope.create_local(int_64);
let z = scope.create_local(int_64);

gpu!(scope, x = input + 5i32);
gpu!(scope, y = cast(x));
gpu!(scope, z = y + 2i64);
cube_inline!(scope, x = input + 5i32);
cube_inline!(scope, y = cast(x));
cube_inline!(scope, z = y + 2i64);

format!("{:?}", scope.operations)
}
24 changes: 13 additions & 11 deletions crates/burn-cube/tests/for_loop.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt, F32};
use burn_jit::gpu;
use burn_jit::gpu::{Elem, Item, Variable};
use burn_jit::{
cube_inline,
gpu::{Elem, Item, Variable},
};

type ElemType = F32;

Expand All @@ -26,7 +28,7 @@ fn test_for_loop_with_unroll() {
for_loop::expand::<ElemType>(&mut context, lhs, rhs, end, unroll);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll));
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll));
}

#[test]
Expand All @@ -41,10 +43,10 @@ fn test_for_loop_no_unroll() {
for_loop::expand::<ElemType>(&mut context, lhs, rhs, end, unroll);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll));
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll));
}

fn gpu_macro_ref(unroll: bool) -> String {
fn inline_macro_ref(unroll: bool) -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(Elem::Float(ElemType::into_kind()));

Expand All @@ -58,15 +60,15 @@ fn gpu_macro_ref(unroll: bool) -> String {
// Kernel
let tmp1 = scope.create_local(item);
let tmp2 = scope.create_local(item);
gpu!(scope, tmp1 = rhs * rhs);
gpu!(scope, tmp2 = tmp1 + rhs);
cube_inline!(scope, tmp1 = rhs * rhs);
cube_inline!(scope, tmp2 = tmp1 + rhs);

gpu!(
cube_inline!(
&mut scope,
range(0u32, end, unroll).for_each(|i, scope| {
gpu!(scope, rhs = lhs[i]);
gpu!(scope, tmp1 = tmp2 + rhs);
gpu!(scope, lhs[i] = tmp1);
cube_inline!(scope, rhs = lhs[i]);
cube_inline!(scope, tmp1 = tmp2 + rhs);
cube_inline!(scope, lhs[i] = tmp1);
})
);

Expand Down
13 changes: 6 additions & 7 deletions crates/burn-cube/tests/if.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use burn_cube::{cube, if_expand, CubeContext, Float, F32};
use burn_jit::gpu;
use burn_jit::gpu::{Elem, Item, Variable};
use burn_jit::{cube_inline, gpu::{Elem, Item, Variable}};

type ElemType = F32;

Expand All @@ -20,10 +19,10 @@ fn cube_if_test() {
if_greater::expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref());
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
}

fn gpu_macro_ref() -> String {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(Elem::Float(ElemType::into_kind()));
let lhs = context.create_local(item);
Expand All @@ -33,9 +32,9 @@ fn gpu_macro_ref() -> String {
let lhs: Variable = lhs.into();
let y = scope.create_local(item);

gpu!(scope, cond = lhs > 0f32);
gpu!(&mut scope, if(cond).then(|scope| {
gpu!(scope, y = lhs + 4.0f32);
cube_inline!(scope, cond = lhs > 0f32);
cube_inline!(&mut scope, if(cond).then(|scope| {
cube_inline!(scope, y = lhs + 4.0f32);
}));

format!("{:?}", scope.operations)
Expand Down
18 changes: 10 additions & 8 deletions crates/burn-cube/tests/if_else.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use burn_cube::{cube, if_else_expand, CubeContext, Float, F32};
use burn_jit::gpu;
use burn_jit::gpu::{Elem, Item, Variable};
use burn_jit::{
cube_inline,
gpu::{Elem, Item, Variable},
};

type ElemType = F32;

Expand All @@ -22,10 +24,10 @@ fn cube_if_else_test() {
if_else::expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref());
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
}

fn gpu_macro_ref() -> String {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(Elem::Float(ElemType::into_kind()));
let lhs = context.create_local(item);
Expand All @@ -35,11 +37,11 @@ fn gpu_macro_ref() -> String {
let lhs: Variable = lhs.into();
let y = scope.create_local(item);

gpu!(scope, cond = lhs < 0f32);
gpu!(&mut scope, if(cond).then(|scope| {
gpu!(scope, y = lhs + 4.0f32);
cube_inline!(scope, cond = lhs < 0f32);
cube_inline!(&mut scope, if(cond).then(|scope| {
cube_inline!(scope, y = lhs + 4.0f32);
}).else(|scope|{
gpu!(scope, y = lhs - 5.0f32);
cube_inline!(scope, y = lhs - 5.0f32);
}));

format!("{:?}", scope.operations)
Expand Down
9 changes: 4 additions & 5 deletions crates/burn-cube/tests/literal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use burn_cube::{cube, CubeContext, Float, F32};
use burn_jit::gpu;
use burn_jit::gpu::{Elem, Item};
use burn_jit::{cube_inline, gpu::{Elem, Item}};

type ElemType = F32;

Expand All @@ -18,17 +17,17 @@ fn cube_literal_test() {
literal::expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref());
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
}

fn gpu_macro_ref() -> String {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(Elem::Float(ElemType::into_kind()));
let lhs = context.create_local(item);

let mut scope = context.into_scope();
let out = scope.create_local(item);
gpu!(scope, out = lhs + 5.9f32);
cube_inline!(scope, out = lhs + 5.9f32);

format!("{:?}", scope.operations)
}
16 changes: 8 additions & 8 deletions crates/burn-cube/tests/loop.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use burn_cube::{
break_expand, cube, if_expand, loop_expand, while_loop_expand, CubeContext, Int, I32,
};
use burn_jit::gpu;
use burn_jit::cube_inline;
use burn_jit::gpu::Branch;
use burn_jit::gpu::{Elem, Item, Variable};

Expand Down Expand Up @@ -33,7 +33,7 @@ fn cube_while_test() {
while_not::expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref());
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
}

#[test]
Expand All @@ -45,10 +45,10 @@ fn cube_loop_break_test() {
manual_loop_break::expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref());
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
}

fn gpu_macro_ref() -> String {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(Elem::Int(ElemType::into_kind()));
let lhs = context.create_local(item);
Expand All @@ -58,15 +58,15 @@ fn gpu_macro_ref() -> String {
let lhs: Variable = lhs.into();
let rhs = scope.create_local(item);

gpu!(
cube_inline!(
&mut scope,
loop(|scope| {
gpu!(scope, cond = lhs != 0);
gpu!(scope, if(cond).then(|scope|{
cube_inline!(scope, cond = lhs != 0);
cube_inline!(scope, if(cond).then(|scope|{
scope.register(Branch::Break);
}));

gpu!(scope, rhs = lhs - 1i32);
cube_inline!(scope, rhs = lhs - 1i32);
})
);

Expand Down
Loading