Add float cast op for JIT backend #2511

Merged: laggui merged 20 commits into main from feat/op/jit-cast on Nov 21, 2024
Conversation

laggui (Member) commented Nov 18, 2024

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Follow-up of #2483

Changes

Added the float_cast implementation for burn-jit, burn-fusion, and burn-router.

  • Removed the element type generic from JitTensor
    • JitTensor now carries a dtype field
    • Refactored all calls with generics to propagate the correct generic types (new execute_with_dtype macro to call cubecl kernels; see the sketch after this list)
  • Added a Primitive trait for backend tensor primitives
    • Motivation: all primitives previously had the same trait requirements, and we now need to report the correct dtype in the Display implementation for Tensor (K::elem_type_name() would report the default floating point dtype even for tensors cast to another precision)
  • Added dtype checks for binary float ops (no automatic type promotion for now, so we panic on mismatched dtypes)
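For illustration, here is a minimal sketch of what a dtype-dispatch macro in the spirit of execute_with_dtype can look like. The dtype list is abbreviated and the macro body is an assumption for exposition, not the exact burn-jit implementation:

// Hedged sketch (not the actual burn-jit macro): dispatch a generic kernel
// call on the runtime dtype carried by the tensor. The caller names an
// element-type alias (E below) that the body uses as its generic argument.
macro_rules! execute_with_dtype {
    ($dtype:expr, $element:ident, $body:expr) => {
        match $dtype {
            DType::F64 => {
                type $element = f64;
                $body
            }
            DType::F32 => {
                type $element = f32;
                $body
            }
            DType::F16 => {
                type $element = half::f16; // `half` crate, as used by burn
                $body
            }
            dtype => panic!("Unsupported dtype {dtype:?}"),
        }
    };
}

// Usage with a hypothetical kernel launch function:
// execute_with_dtype!(tensor.dtype, E, launch_kernel::<E>(&tensor));

The point is that the dtype now lives on the tensor at runtime, and the generic instantiation happens at the dispatch site instead of being fixed by a JitTensor type parameter.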

Known issue
Trying to concatenate tensors cast to a different floating point dtype will fail with the current cat_with_slice_assign implementation because the output tensor is initialized with the default floating point element precision in K::empty(...).

let tensor1: Tensor<B, 2> = Tensor::from_floats([[2., 3.], [4., 5.]], &device).cast(DType::F16);
let tensor2 = Tensor::ones_like(&tensor1).cast(DType::F16);
let tensor3 = Tensor::zeros_like(&tensor1).cast(DType::F16);

// Panics with incompatible dtype (for output)
let output = Tensor::<B, 2>::cat(vec![tensor1, tensor2, tensor3], 0);
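To make the failure mode concrete, here is a self-contained toy model (plain Rust, not burn code) of why allocating the output with the default dtype breaks, and how inheriting the inputs' dtype would fix it:

// Toy model of the known issue, not burn code.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType { F32, F16 }

struct Tensor { dtype: DType, data: Vec<u8> }

fn slice_assign(mut out: Tensor, inputs: &[Tensor]) -> Tensor {
    for t in inputs {
        // Mirrors the dtype check added for binary/assign ops.
        assert_eq!(out.dtype, t.dtype, "incompatible dtype (for output)");
        out.data.extend_from_slice(&t.data);
    }
    out
}

// Buggy pattern: like K::empty(...), the output is always allocated with
// the default float dtype.
fn cat_buggy(inputs: &[Tensor]) -> Tensor {
    let out = Tensor { dtype: DType::F32, data: Vec::new() };
    slice_assign(out, inputs) // panics when the inputs were cast to F16
}

// Fix: allocate the output with the dtype of the inputs.
fn cat_fixed(inputs: &[Tensor]) -> Tensor {
    let out = Tensor { dtype: inputs[0].dtype, data: Vec::new() };
    slice_assign(out, inputs)
}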

TODO

  • Implement float_cast for ndarray
  • Remove redundant precision bridge

Testing

Added new tests for the primitive trait.
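As a rough illustration (the exact tests in the PR may differ; TestBackend follows burn's usual test convention, and the dtype() accessor is assumed to be backed by the new Primitive trait), such a test can assert that a cast tensor reports its actual dtype rather than the backend default:

#[test]
fn cast_tensor_should_report_correct_dtype() {
    let device = Default::default();
    let tensor = Tensor::<TestBackend, 2>::from_floats([[2., 3.], [4., 5.]], &device)
        .cast(DType::F16);

    // Before the Primitive trait, the reported element type came from
    // K::elem_type_name() and would still be the default float dtype here.
    assert_eq!(tensor.dtype(), DType::F16);
}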

codecov bot commented Nov 19, 2024

Codecov Report

Attention: Patch coverage is 80.84337% with 318 lines in your changes missing coverage. Please review.

Project coverage is 82.53%. Comparing base (8be2032) to head (5e74f12).
Report is 2 commits behind head on main.

Files with missing lines | Patch % | Lines
crates/burn-jit/src/ops/float_ops.rs | 50.36% | 135 Missing ⚠️
crates/burn-fusion/src/ops/float.rs | 88.50% | 23 Missing ⚠️
crates/burn-tch/src/tensor.rs | 12.50% | 21 Missing ⚠️
crates/burn-router/src/ops/op_float.rs | 0.00% | 14 Missing ⚠️
...s/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs | 0.00% | 13 Missing ⚠️
crates/burn-tensor/src/tensor/element/base.rs | 45.83% | 13 Missing ⚠️
crates/burn-candle/src/tensor.rs | 47.36% | 10 Missing ⚠️
crates/burn-jit/src/kernel/comparison.rs | 79.16% | 10 Missing ⚠️
crates/burn-jit/src/kernel/contiguous.rs | 16.66% | 10 Missing ⚠️
crates/burn-fusion/src/tensor.rs | 50.00% | 6 Missing ⚠️
... and 25 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2511      +/-   ##
==========================================
- Coverage   82.69%   82.53%   -0.17%     
==========================================
  Files         826      827       +1     
  Lines      106405   106689     +284     
==========================================
+ Hits        87996    88057      +61     
- Misses      18409    18632     +223     



Comment on lines 44 to 47

pub trait Primitive: Clone + Send + Sync + core::fmt::Debug {
    /// The dtype of the tensor.
    fn dtype(&self) -> DType;
}
Member

Maybe it would be a good idea to put all tensor metadata here, e.g. to also include the shape.

laggui (Member, Author) commented Nov 19, 2024

Yeah I only needed the dtype for the current changes but it makes sense to expand!

/edit: if we add that to the primitive trait, we could probably remove the *_shape ops for the different tensor types and have a single implementation that uses the primitive shape (see the sketch below).
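A hedged sketch of that idea (the shape method and the helper below are hypothetical follow-up, not part of this PR; Shape stands in for burn-tensor's shape type with its exact signature glossed):

// Hypothetical extension of the Primitive trait from this PR: expose the
// shape alongside the dtype so metadata queries no longer need per-kind ops.
pub trait Primitive: Clone + Send + Sync + core::fmt::Debug {
    /// The dtype of the tensor.
    fn dtype(&self) -> DType;
    /// The shape of the tensor (proposed addition).
    fn shape(&self) -> Shape;
}

// With that in place, a single implementation could replace the per-kind
// float_shape / int_shape / bool_shape ops on the backend trait.
fn shape_of<P: Primitive>(primitive: &P) -> Shape {
    primitive.shape()
}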

laggui marked this pull request as ready for review November 19, 2024 18:00
laggui (Member, Author) commented Nov 21, 2024

@nathanielsimard

laggui merged commit 9a2b841 into main Nov 21, 2024
11 checks passed
laggui deleted the feat/op/jit-cast branch November 21, 2024 18:19