
Wgpu/Clamp Kernels #866

Merged · 19 commits · Oct 23, 2023
Conversation

@agelas (Contributor) commented Oct 15, 2023

Pull Request Template

Checklist

  • Confirmed that the run-checks script has been executed.

Related Issues/PRs

#549

Changes

  1. Added clamp kernels and shaders to wgpu backend.
  2. @antimora I used the mask_fill method that was formerly the default for the candle backend. I took a peek at candle-core and it looks like they added their own clamp method about a month ago, so we can update once burn-candle is pinned to a more recent version.

Testing

Linked backends to their respective clamp_min/clamp_max implementations instead of mask_fill.

@nathanielsimard (Member) left a comment:

The kernels look good, but I would probably make the < and > operations a template key. You would need to create your own KernelTemplate wrapper to fill the template, so it's more Rust code vs. more WGSL code; it's a very optional refactor.
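The template-key idea could be sketched as plain string substitution: one WGSL source with a placeholder comparison operator, filled per kernel variant. This is a minimal sketch; `CLAMP_BOUND_TEMPLATE`, `{{CMP_OP}}`, and `render` are hypothetical names, not the actual burn-wgpu KernelTemplate API.

```rust
// Hypothetical sketch of the template-key idea: a single WGSL snippet
// with a placeholder comparison operator shared by clamp_min/clamp_max.
const CLAMP_BOUND_TEMPLATE: &str = r#"
if input[id] {{CMP_OP}} bound {
    output[id] = bound;
} else {
    output[id] = input[id];
}
"#;

/// Fill the placeholder for either the min (`<`) or max (`>`) variant.
fn render(template: &str, cmp_op: &str) -> String {
    template.replace("{{CMP_OP}}", cmp_op)
}

fn main() {
    let clamp_min_src = render(CLAMP_BOUND_TEMPLATE, "<");
    let clamp_max_src = render(CLAMP_BOUND_TEMPLATE, ">");
    assert!(clamp_min_src.contains("input[id] < bound"));
    assert!(clamp_max_src.contains("input[id] > bound"));
    println!("templates rendered");
}
```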

Comment on lines 426 to 448
fn clamp_min<const D: usize>(
    tensor: FloatTensor<Self, D>,
    min: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    let mask = Self::lower_elem(tensor.clone(), min);
    Self::mask_fill(tensor, mask, min)
}

fn clamp_max<const D: usize>(
    tensor: FloatTensor<Self, D>,
    max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    let mask = Self::greater_elem(tensor.clone(), max);
    Self::mask_fill(tensor, mask, max)
}

fn clamp<const D: usize>(
    tensor: FloatTensor<Self, D>,
    min: FloatElem<Self>,
    max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
    Self::clamp_min(Self::clamp_max(tensor, max), min)
}
Member:

Those are the default implementations, we should implement those methods only if we actually add custom implementations.

Author (@agelas):

Should I take those out of here and put them back in burn-tensor/src/tensor/ops/tensor.rs?

Member:

Yes!

burn-tensor/src/tensor/ops/tensor.rs (outdated, resolved)
burn-tensor/src/tensor/ops/tensor.rs (outdated, resolved)
@louisfd (Member) left a comment:

The main problem I see is that your kernel is always in-place. Normally we like to have one kernel for when tensors can be mutated (in-place), and another for when they cannot (because there are too many references to them), which creates an output buffer. This logic is already implemented; look for the unary ops. Using the unary template, you will only have to provide the logic for each value. For an example, see relu in burn-wgpu/src/ops/activation_ops.rs.
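The in-place vs. out-of-place decision can be illustrated with plain Rust reference counting: mutate the buffer only when we hold the sole reference, otherwise allocate a fresh output. This is a conceptual sketch only; burn-wgpu's actual buffer-reuse logic works on GPU handles, not `Rc`.

```rust
use std::rc::Rc;

// Conceptual sketch: reuse the allocation ("in-place kernel") only
// when we are the sole owner; otherwise write into a new buffer.
fn clamp_max_dispatch(buffer: Rc<Vec<f32>>, max: f32) -> Rc<Vec<f32>> {
    match Rc::try_unwrap(buffer) {
        // Sole owner: mutate in place.
        Ok(mut data) => {
            for v in data.iter_mut() {
                if *v > max {
                    *v = max;
                }
            }
            Rc::new(data)
        }
        // Shared: the out-of-place path allocates an output buffer.
        Err(shared) => Rc::new(shared.iter().map(|v| v.min(max)).collect()),
    }
}

fn main() {
    // Sole owner takes the in-place path.
    let out = clamp_max_dispatch(Rc::new(vec![0.5, 2.0]), 1.0);
    assert_eq!(*out, vec![0.5, 1.0]);

    // A shared buffer takes the out-of-place path; the input survives.
    let input = Rc::new(vec![0.5, 2.0]);
    let out = clamp_max_dispatch(Rc::clone(&input), 1.0);
    assert_eq!(*input, vec![0.5, 2.0]);
    assert_eq!(*out, vec![0.5, 1.0]);
    println!("ok");
}
```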

Also, creating a third kernel for clamp (with both lower and upper bounds) would be very low-hanging fruit once you have made the other two, and it would be a great optimization, as we would launch only one kernel instead of two in this case.
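The fusion win can be checked on the CPU: one pass applying both bounds produces the same result as chaining a clamp_max pass into a clamp_min pass. A sketch of the idea only, not the wgpu kernel itself:

```rust
// Two separate passes, mirroring launching two kernels.
fn clamp_min_pass(data: &[f32], min: f32) -> Vec<f32> {
    data.iter().map(|v| v.max(min)).collect()
}
fn clamp_max_pass(data: &[f32], max: f32) -> Vec<f32> {
    data.iter().map(|v| v.min(max)).collect()
}

// One fused pass, mirroring a single clamp kernel with both bounds.
fn clamp_fused(data: &[f32], min: f32, max: f32) -> Vec<f32> {
    data.iter().map(|v| v.max(min).min(max)).collect()
}

fn main() {
    let data = [-2.0_f32, -0.5, 0.5, 3.0];
    let chained = clamp_min_pass(&clamp_max_pass(&data, 1.0), 0.0);
    let fused = clamp_fused(&data, 0.0, 1.0);
    assert_eq!(chained, fused);
    assert_eq!(fused, vec![0.0, 0.0, 0.5, 1.0]);
    println!("fused == chained");
}
```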

@agelas (Contributor, Author) commented Oct 20, 2023

@louisfd I see what you're saying. I'm not super familiar with proc macros, so correct me if I'm wrong, but don't they expect static literal inputs? In other words, unary! or unary_inplace! can't evaluate the clamp value when the macros are expanded at compile time. So as an alternative I can just write one kernel for in-place and one for not. Looking at some WGSL references, it turns out clamp(value, min, max) is a built-in function, so I can handle all three variations.

@louisfd (Member) commented Oct 20, 2023

@agelas
You're right, unary is not the right one; I got fooled by relu because it's kind of a clamp, but its clamp value is hardcoded.
It would rather be unary_scalar, where you can pass a second, scalar argument.

@agelas (Contributor, Author) commented Oct 21, 2023

@louisfd I think this is closer to what you're looking for now. There are no proc macros that really fit the pattern I'd need to make use of WGSL's built-in clamp, so I wrote kernels for in-place and not in-place.

@louisfd (Member) left a comment:

It looks very good now!
For candle, if we only implement clamp and not clamp_min and clamp_max, those two will fall back to their default implementations, which use a combination of kernels, so they will be slower than clamp. Looking at candle's code, it seems clamp is just a combination of minimum and maximum, which are public, so it should be trivial to have clamp_min call maximum and clamp_max call minimum.
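The suggested mapping rests on the identities clamp_min(x, m) = max(x, m) and clamp_max(x, m) = min(x, m), which can be checked in plain Rust on scalars (candle's actual Tensor API is not used here):

```rust
fn main() {
    let values = [-3.0_f32, -1.0, 0.0, 1.0, 3.0];
    for v in values {
        // Clamping from below is exactly an elementwise maximum...
        assert_eq!(v.max(-1.5), if v < -1.5 { -1.5 } else { v });
        // ...and clamping from above is exactly an elementwise minimum.
        assert_eq!(v.min(1.5), if v > 1.5 { 1.5 } else { v });
    }
    println!("identities hold");
}
```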

@agelas (Contributor, Author) commented Oct 23, 2023

@louisfd Done!

@louisfd (Member) left a comment:

Thanks a lot @agelas !
Merging it.

@louisfd louisfd merged commit 07c0cf1 into tracel-ai:main Oct 23, 2023