Add float cast op for JIT backend #2511
Conversation
Codecov Report

Attention: Patch coverage is

```
@@            Coverage Diff             @@
##             main    #2511      +/-   ##
==========================================
- Coverage   82.69%   82.53%   -0.17%
==========================================
  Files         826      827       +1
  Lines      106405   106689     +284
==========================================
+ Hits        87996    88057      +61
- Misses      18409    18632     +223
```
```rust
pub trait Primitive: Clone + Send + Sync + core::fmt::Debug {
    /// The dtype of the tensor.
    fn dtype(&self) -> DType;
}
```
Maybe it would be a good idea to put all tensor metadata here, so as to also include the shape.
Yeah, I only needed the dtype for the current changes, but it makes sense to expand!

/edit: if we add that to the primitive trait, we could probably remove the `*_shape` ops for the different tensor types and have a global implementation that uses the primitive shape.
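Sketching the suggestion above: a `Primitive` trait that also carries the shape might look roughly like this. This is a self-contained toy, not Burn's actual code; `ToyTensor` and the two-variant `DType` are invented for illustration.

```rust
use std::fmt::Debug;

// Minimal stand-in for Burn's DType (illustration only).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DType {
    F32,
    F16,
}

// Hypothetical extension of the trait from the snippet above:
// metadata now includes the shape, not just the dtype.
pub trait Primitive: Clone + Send + Sync + Debug {
    /// The dtype of the tensor.
    fn dtype(&self) -> DType;
    /// The shape of the tensor (hypothetical addition).
    fn shape(&self) -> Vec<usize>;
}

// Toy tensor primitive used only to exercise the trait.
#[derive(Clone, Debug)]
struct ToyTensor {
    dtype: DType,
    shape: Vec<usize>,
}

impl Primitive for ToyTensor {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn shape(&self) -> Vec<usize> {
        self.shape.clone()
    }
}

fn main() {
    let t = ToyTensor { dtype: DType::F16, shape: vec![2, 3] };
    println!("{:?} {:?}", t.dtype(), t.shape()); // F16 [2, 3]
}
```

With the shape available on the trait, a single blanket `*_shape` implementation could read it from the primitive instead of each tensor kind providing its own op.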
Checklist

- [x] `run-checks all` script has been executed.

Related Issues/PRs
Follow-up of #2483
Changes

- Added `float_cast` implementation for `burn-jit`, `burn-fusion` and `burn-router`.
- `JitTensor` now contains the dtype field (+ `execute_with_dtype` macro to call cubecl kernels).
- `Primitive` trait for backend tensor primitives.
- `Display` implementation for `Tensor` (`K::elem_type_name()` would report the default floating point dtype even for tensors cast to another precision).

Known issue
Trying to concatenate tensors cast to a different floating point dtype will fail with the current `cat_with_slice_assign` implementation because the output tensor is initialized with the default floating point element precision in `K::empty(...)`.

TODO
- `float_cast` for ndarray

Testing
Added new tests for primitive trait.
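For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of dtype-based dispatch in the spirit of the `execute_with_dtype` macro mentioned in the changes. Only the macro's name is taken from this PR; the `DType` enum, `kernel` function, and the f32 stand-in for f16 are invented for illustration.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F32,
    F16,
}

// Toy "kernel" generic over the element type, standing in for a cubecl kernel.
fn kernel<E: Copy + Into<f64>>(data: &[E]) -> f64 {
    data.iter().map(|&e| e.into()).sum()
}

// Dispatch on the runtime dtype to a concrete, monomorphized kernel call.
macro_rules! execute_with_dtype {
    ($dtype:expr, $data:expr) => {
        match $dtype {
            DType::F32 => kernel::<f32>($data),
            // `f16` is not a core Rust type; f32 stands in for it here.
            DType::F16 => kernel::<f32>($data),
        }
    };
}

fn main() {
    let data = [1.0f32, 2.0, 3.5];
    let total = execute_with_dtype!(DType::F32, &data);
    println!("sum = {total}"); // sum = 6.5
}
```

The point of the pattern is that the runtime dtype stored on the tensor selects a statically typed kernel instantiation at the call site.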