-
Notifications
You must be signed in to change notification settings - Fork 489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add more type support for burn-jit #2454
Conversation
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), | ||
SequenceArg::new(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if there is a better way to do this 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tbh I don't understand fusion at all, so I just copied what was there for all the new types 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's simply that each precision has its state. It would be cleaner to have a map Map<ElemenType, State>
, instead of listing all posibilities. I'll think how I can refactor it, but not in this PR.
use super::*; | ||
use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor}; | ||
|
||
pub type TestBackend = burn_fusion::Fusion<JitBackend<TestRuntime, f32, i32>>; | ||
pub type TestBackend2<F, I> = burn_fusion::Fusion<JitBackend<TestRuntime, F, I>>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TestBackend2 is the one that can be injected with the precision right? Not sure if we could come up with a better name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't come up with a better name and it's only used in the macro to instantiate and override TestBackend
with the correct types, so I left it for now. It's not used in the actual tests, just the macro.
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
Implements the new types and parameterized testing from tracel-ai/cubecl#207
Changes
Adds support for
flex32
,i8
,i16
,i64
,u8
,u16
,u64
support to burn, parameterize tests to ensure all types work correctlyTesting
New parameterized tests