Wgpu/Clamp Kernels #866
Conversation
The kernels look good, but I would probably have the `<` and `>` operations be a template key. You would need to create your own `KernelTemplate` wrapper to fill the template, so it's more Rust code vs. more wgsl code, but it's a very optional refactor.
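For illustration, the idea could look something like the sketch below. This is a minimal sketch of template-key substitution, not burn's actual `KernelTemplate` API; the struct and the `{{ cmp }}` placeholder name are made up:

```rust
// Hypothetical wrapper: one shared WGSL source serves both mask kernels,
// with the comparison operator filled in at kernel-build time.
struct ClampMaskTemplate {
    cmp: &'static str, // "<" for clamp_min, ">" for clamp_max
}

impl ClampMaskTemplate {
    fn source(&self, wgsl_template: &str) -> String {
        // Replace the assumed `{{ cmp }}` key in the shared shader source.
        wgsl_template.replace("{{ cmp }}", self.cmp)
    }
}
```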
burn-candle/src/ops/tensor.rs (Outdated)
```rust
fn clamp_min<const D: usize>(
    tensor: FloatTensor<Self, D>,
    min: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    let mask = Self::lower_elem(tensor.clone(), min);
    Self::mask_fill(tensor, mask, min)
}

fn clamp_max<const D: usize>(
    tensor: FloatTensor<Self, D>,
    max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    let mask = Self::greater_elem(tensor.clone(), max);
    Self::mask_fill(tensor, mask, max)
}

fn clamp<const D: usize>(
    tensor: FloatTensor<Self, D>,
    min: FloatElem<Self>,
    max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    Self::clamp_min(Self::clamp_max(tensor, max), min)
}
```
Those are the default implementations; we should implement those methods here only if we actually add custom implementations.
Should I take those out of here and put them back in `burn-tensor/src/tensor/ops/tensor.rs`?
Yes!
The main problem I see is that your kernel is always in-place. Normally we like to have a kernel for when tensors can be mutated (in-place), and another for when they cannot (because there are too many references to them), which will create an output buffer. This logic is already implemented; look for the unary ops. Using the unary template, you will only have to provide the logic for each value. For an example you can look at relu in burn-wgpu/src/ops/activation_ops.rs.
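As a rough illustration of that pattern, `clamp_min` could follow the same shape as relu. Everything here is a sketch: the macro and helper names (`unary_scalar!`, `unary_scalar_inplace!`, `unary_scalar_default`, `unary_scalar_inplace_default`) are assumptions modeled on the unary helpers, not verified signatures:

```rust
// Sketch only: assumed unary-scalar macros/helpers, not verified against burn-wgpu.
fn clamp_min<const D: usize>(
    tensor: FloatTensor<Self, D>,
    min: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    // Elementwise, clamp_min(x, min) == max(x, min).
    unary_scalar!(ClampMin, body "output[id] = max(input[id], scalar);");
    unary_scalar_inplace!(ClampMinInplace, body "input[id] = max(input[id], scalar);");

    if tensor.can_mut() {
        // No other references to the tensor: mutate its buffer in-place.
        return unary_scalar_inplace_default::<ClampMinInplace, F, D>(tensor, min);
    }
    // Otherwise launch the variant that writes to a fresh output buffer.
    unary_scalar_default::<ClampMin, F, D>(tensor, min)
}
```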
Also, creating a third kernel for clamp (with both lower and upper bounds) would be very low-hanging fruit once you have made the other two, and it would be a great optimization, as we would only launch one kernel instead of two in this case.
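For what it's worth, the fused kernel's body can be as small as one line of wgsl using the built-in `clamp`, e.g. `output[id] = clamp(input[id], min_value, max_value);` (where `min_value`/`max_value` are hypothetical scalar bindings), so a single launch handles both bounds at once.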
@louisfd I see what you're saying. I'm not super familiar with proc macros, so correct me if I'm wrong, but don't they expect static literal inputs? In other words, the …
@agelas …
@louisfd I think this is closer to what you're looking for now. There are no proc macros that really fit the pattern I'd need to make use of wgsl's built-in `clamp`.
It looks very good now!
For candle, if we only implement clamp and not clamp_min and clamp_max, then those two will fall back to their default implementation that uses a combination of kernels, so they will be slower than clamp. Looking at candle's code, it seems clamp is just a combination of `minimum` and `maximum`, which are public, so it should be trivial to have clamp_min call `maximum` and clamp_max call `minimum`.
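A minimal sketch of that mapping, assuming candle's public `Tensor::maximum`/`Tensor::minimum` accept a scalar bound and that `CandleTensor::new`/`tensor.tensor` are how burn-candle wraps and unwraps the inner tensor (these wrapper details are assumptions, not the merged code):

```rust
// Sketch: clamp_min is an elementwise maximum with the lower bound,
// clamp_max an elementwise minimum with the upper bound.
fn clamp_min<const D: usize>(
    tensor: FloatTensor<Self, D>,
    min: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    CandleTensor::new(tensor.tensor.maximum(min).unwrap())
}

fn clamp_max<const D: usize>(
    tensor: FloatTensor<Self, D>,
    max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    CandleTensor::new(tensor.tensor.minimum(max).unwrap())
}
```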
@louisfd Done!
Thanks a lot @agelas!
Merging it.
Pull Request Template

Checklist

- The `run-checks` script has been executed.

Related Issues/PRs

#549

Changes

- Added `clamp` kernels and shaders to the wgpu backend.
- Replaced the `mask_fill` method that used to be the default for the candle backend. I took a peek at `candle-core`, and it looks like they added their own `clamp` method about a month ago, so we can update once `burn-candle` is pinned to a more recent version.

Testing

Linked backends to their respective `clamp_min`/`clamp_max` implementations instead of `mask_fill`.
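As a quick smoke-test sketch of what that exercises (hypothetical values; assumes burn's public `Tensor::clamp` and the usual `TestBackend` alias from the test suite):

```rust
// Hypothetical check: values below/above the bounds are clipped to them.
let tensor = Tensor::<TestBackend, 2>::from_floats([[-1.0, 0.5], [2.0, 3.0]]);
let output = tensor.clamp(0.0, 1.0);
assert_eq!(output.into_data(), Data::from([[0.0, 0.5], [1.0, 1.0]]));
```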