-
Notifications
You must be signed in to change notification settings - Fork 469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove precision bridge and complete float_cast
implementation for ndarray
#2538
Conversation
- Change primitive to `NdArrayTensorFloat` enum - Use `execute_with_float_dtype` macro to handle ops w/ generics based on the float enum variant
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2538 +/- ##
========================================
Coverage 82.56% 82.56%
========================================
Files 827 822 -5
Lines 106855 106964 +109
========================================
+ Hits 88226 88320 +94
- Misses 18629 18644 +15 ☔ View full report in Codecov by Sentry. |
crates/burn-ndarray/src/tensor.rs
Outdated
// pub(crate) fn cast_to<O: Element>(self) -> NdArrayTensor<O> { | ||
// let array = self.array.mapv(|a| a.elem()).into_shared(); | ||
|
||
// NdArrayTensor::new(array) | ||
// } | ||
|
||
// pub(crate) fn to_float(self) -> NdArrayTensorFloat { | ||
// match E::dtype() { | ||
// DType::F64 => NdArrayTensorFloat::F64(self.cast_to()), | ||
// DType::F32 => NdArrayTensorFloat::F32(self.cast_to()), | ||
// _ => unimplemented!("Invalid float type"), | ||
// } | ||
// } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dead code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dead code detector on point, as always 😄
Checklist
run-checks all
script has been executed.Related Issues/PRs
#2483 #2511
Changes
tensor.cast(dtype)
tensor.dtype()
method (helpful to perform to/from a type, e.g. full precision usage)float_cast
for ndarray backendNdArrayTensorFloat
enumexecute_with_float_dtype
macro which handles ops w/ generics based on the float enum variantTesting
Cast unit test + float ops unit tests