Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/cube/expand & fix double imports #2009

Merged
merged 8 commits into from
Jul 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion crates/burn-cube-macros/src/codegen_common/signature.rs
Original file line number Diff line number Diff line change
@@ -2,10 +2,17 @@ use quote::ToTokens;

use crate::tracker::VariableTracker;

#[derive(Copy, Clone, Debug)]
pub enum ExpandMode {
FuncImpl,
MethodImpl,
}

pub fn expand_sig(
sig: &syn::Signature,
visibility: &syn::Visibility,
mut variable_tracker: Option<&mut VariableTracker>,
mode: ExpandMode,
) -> proc_macro2::TokenStream {
let mut inputs = quote::quote!();

@@ -42,7 +49,10 @@ pub fn expand_sig(
}

let ident = &sig.ident;
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let ident = match mode {
ExpandMode::FuncImpl => syn::Ident::new("__expand".to_string().as_str(), ident.span()),
_ => syn::Ident::new(format!("__expand_{ident}").as_str(), ident.span()),
};

let generics = sig.generics.clone().into_token_stream();

30 changes: 25 additions & 5 deletions crates/burn-cube-macros/src/codegen_function/function.rs
Original file line number Diff line number Diff line change
@@ -106,22 +106,42 @@ pub(crate) fn codegen_call(
// Path
let mut path_tokens = TokenStream::new();
let mut is_comptime = false;
let mut is_plain_func = true;
let mut comptime_func: Option<String> = None;

for (i, (ident, generics)) in path.iter().enumerate() {
if *ident == "Comptime" {
let name = ident.to_string();

if name == "Comptime" {
is_comptime = true;
continue;
}

if let Some(first_char) = name.chars().next() {
if first_char.is_uppercase() {
is_plain_func = false;
}
}

if i == path.len() - 1 {
if is_comptime {
comptime_func = Some(ident.to_string());
break;
}
let func_name_expand = syn::Ident::new(
format!("{ident}_expand").as_str(),
proc_macro2::Span::call_site(),
);

let func_name_expand = if is_plain_func {
quote::quote! {
#ident::__expand
}
} else {
let ident = syn::Ident::new(
format!("__expand_{ident}").as_str(),
proc_macro2::Span::call_site(),
);
quote::quote! {
#ident
}
};
path_tokens.extend(quote_spanned! {func_name_expand.span() => #func_name_expand });
if let Some(generics) = generics {
path_tokens.extend(quote_spanned! {generics.span() => #generics });
17 changes: 11 additions & 6 deletions crates/burn-cube-macros/src/codegen_function/launch.rs
Original file line number Diff line number Diff line change
@@ -211,7 +211,7 @@ impl Codegen {
}
}

fn gen_define_impl(&self, expand: &Ident) -> TokenStream {
fn gen_define_impl(&self, expand: &TokenStream) -> TokenStream {
let mut expand_args = quote::quote! { &mut builder.context, };

let mut variables = quote::quote! {};
@@ -340,7 +340,7 @@ impl Codegen {
tokens
}

fn gen_compile_impl(&self, expand: &Ident) -> TokenStream {
fn gen_compile_impl(&self, expand: &TokenStream) -> TokenStream {
let ident = Ident::new(&self.name, Span::call_site());
let generics = add_runtime(self.generics.clone());
let (impl_gen, ty_gen, where_gen) = generics.split_for_impl();
@@ -453,22 +453,27 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
let codegen = Codegen::from_sig(sig);

let ident = &sig.ident;
let ident_expand = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let ident = syn::Ident::new(format!("{ident}_launch").as_str(), ident.span());

let ident_expand = quote::quote! {
__expand
};

let generics = add_runtime(add_lifetime(sig.generics.clone()));
let body = codegen.gen_launch_body();
let kernel = codegen.gen_kernel_struct();
let compile = codegen.gen_compile_impl(&ident_expand);
let (inputs, output) = (codegen.fn_inputs, codegen.fn_output);
let doc =
format!("Launch the kernel [{ident}] with the provided argument on the given runtime.");

quote::quote! {
#kernel
#compile

#[allow(clippy::too_many_arguments)]
/// Launch
pub fn #ident #generics (
#[doc = #doc]
/// Launch the kernel.
pub fn launch #generics (
client: ComputeClient<R::Server, R::Channel>,
cube_count: CubeCount<R::Server>,
cube_dim: CubeDim,
24 changes: 18 additions & 6 deletions crates/burn-cube-macros/src/codegen_trait/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
use crate::codegen_common::signature::expand_sig;
use proc_macro2::TokenStream;

use crate::codegen_common::signature::{expand_sig, ExpandMode};

pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
let mut expand_items = Vec::new();

for item in tr.items.iter() {
match item {
syn::TraitItem::Fn(func) => {
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
let expand = expand_sig(
&func.sig,
&syn::Visibility::Inherited,
None,
ExpandMode::MethodImpl,
);
expand_items.push(syn::parse_quote!(#expand;));
}
_ => continue,
@@ -26,7 +33,7 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
match item {
syn::ImplItem::Fn(func) => {
let ident = &func.sig.ident;
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let ident = quote::quote! {#ident::__expand};
let mut inputs = quote::quote!();

for input in &func.sig.inputs {
@@ -41,7 +48,12 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
}
}

let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
let expand = expand_sig(
&func.sig,
&syn::Visibility::Inherited,
None,
ExpandMode::MethodImpl,
);

let tokens = if !tr.generics.params.is_empty() {
let mut func = func.clone();
@@ -67,7 +79,7 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {

fn register_expand(
func: &syn::ImplItemFn,
name: &syn::Ident,
name: &TokenStream,
expand: proc_macro2::TokenStream,
inputs: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
@@ -91,7 +103,7 @@ fn register_expand(
quote::quote! (
#expand {
#[cube]
#func
pub #func
#func_expand
}
)
56 changes: 36 additions & 20 deletions crates/burn-cube-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@ mod tracker;
pub(crate) mod codegen_common;

use analyzer::VariableAnalyzer;
use codegen_common::signature::expand_sig;
use codegen_common::signature::{expand_sig, ExpandMode};
use codegen_function::{codegen_launch, codegen_statement};
use codegen_trait::{expand_trait_def, expand_trait_impl};
use codegen_type::generate_cube_type;
@@ -69,20 +69,8 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream {
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);

match codegen_cube(&func, &mut variable_tracker) {
Ok(code) => {
if attrs.launch {
let launch = codegen_launch(&func.sig);

quote::quote! {
#code
#launch
}
.into()
} else {
code.into()
}
}
match codegen_cube(&func, &mut variable_tracker, attrs.launch) {
Ok(code) => code.into(),
Err(err) => err.into(),
}
}
@@ -120,8 +108,15 @@ fn parse_attributes(args: &Punctuated<Meta, Comma>) -> SupportedAttributes {
fn codegen_cube(
func: &syn::ItemFn,
variable_tracker: &mut VariableTracker,
launch: bool,
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
let signature = expand_sig(&func.sig, &func.vis, Some(variable_tracker));
let signature = expand_sig(
&func.sig,
&syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import
// it from an outside module.
Some(variable_tracker),
ExpandMode::FuncImpl,
);
let mut body = quote::quote! {};

for statement in func.block.stmts.iter() {
@@ -148,15 +143,36 @@ fn codegen_cube(
return Err(code);
}

let launch_doc = if launch { "and launch function " } else { "" };

let launch = if launch {
codegen_launch(&func.sig)
} else {
quote::quote! {}
};

let mod_name = &func.sig.ident;
let vis = &func.vis;
let doc = format!("Module containing the expand method {launch_doc}of {mod_name}.");

Ok(quote::quote! {
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
#func

#[allow(unused_mut)]
#[allow(clippy::too_many_arguments)]
#signature {
#body

#[doc = #doc]
#vis mod #mod_name {
use super::*;

#launch

#[allow(unused_mut)]
#[allow(clippy::too_many_arguments)]
#signature {
#body
}

}
})
}
Loading
Oops, something went wrong.
Loading