Efficient implementation of Tensor::ones() for metal #2512
candle-metal-kernels/src/lib.rs (Outdated)
encoder.use_resource(output, metal::MTLResourceUsage::Write);

let grid_size = MTLSize { width: length as u64, height: 1, depth: 1 };
Is length here correct? Shouldn't it be divided by the number of threads? Anyway, it would be better to use utils::linear_split, as it should do the proper thing here.
this was stupid of me. Resolved.
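The sizing concern raised in this review thread can be sketched in plain Rust. This is a minimal sketch, not candle's actual utils::linear_split: the name split_linear and the max_threads_per_group parameter are assumptions made for illustration. It shows that, if the grid is counted in threadgroups, a 1-D dispatch over length elements needs roughly length divided by the group width (rounded up) groups, rather than length groups.

```rust
/// Hypothetical helper (not candle's API): split `length` elements into
/// (threadgroup_count, threads_per_group) for a 1-D dispatch.
fn split_linear(max_threads_per_group: u64, length: u64) -> (u64, u64) {
    // Use at most `max_threads_per_group` threads per group, and at least 1
    // so that an empty tensor does not cause a division by zero.
    let width = max_threads_per_group.min(length).max(1);
    // Round up so the last, partially filled group is still dispatched.
    let count = (length + width - 1) / width;
    (count, width)
}

fn main() {
    let (count, width) = split_linear(1024, 5000);
    println!("{count} groups of {width} threads");
}
```

Because the group count is rounded up, the total number of launched threads can exceed length, so the kernel itself still has to guard against out-of-range thread ids.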
@LaurentMazare
Yes please.
Thanks!
…#2512)
* WIP: hopefully better const impl
* with GPU
* More tests on
* Reverting primitive for
* Incorporating review changes - added elem count check in kernel, using for call strategy
* rustfmt ran
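The element count check mentioned in the commit message matters because a rounded-up dispatch launches more threads than there are elements. The following CPU-side emulation is a sketch under stated assumptions: fill_thread and emulate_dispatch are hypothetical names, and the real kernel is written for Metal, not in Rust.

```rust
/// Hypothetical per-thread body: thread `tid` writes `value` only if it
/// maps to a real element of the output buffer.
fn fill_thread<T: Copy>(tid: usize, elem_count: usize, value: T, out: &mut [T]) {
    if tid >= elem_count {
        return; // surplus thread from the rounded-up grid: do nothing
    }
    out[tid] = value;
}

/// Emulates launching `groups * width` threads, one after another, on the CPU.
fn emulate_dispatch<T: Copy>(groups: usize, width: usize, value: T, out: &mut [T]) {
    let elem_count = out.len();
    for tid in 0..groups * width {
        fill_thread(tid, elem_count, value, out);
    }
}

fn main() {
    let mut out = vec![0.0f32; 5];
    // 2 groups of 4 threads = 8 threads for 5 elements; 3 threads are surplus.
    emulate_dispatch(2, 4, 1.0f32, &mut out);
    println!("{out:?}");
}
```

Without the early return, the surplus threads would index past the end of the buffer, which panics in this emulation and would be an out-of-bounds write on the GPU.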
While working with Tensor::ones() I realized the implementation is significantly slower than the pytorch counterpart in metal. The original implementation would create a CpuStorage and move it to the GPU.

This implementation introduces a metal kernel function for the Tensor::ones() impl.

Benchmarks (original vs new implementation)
Benchmark code was run on a Mac M1 Pro 16GB.
[Benchmark table: three rows each for u8, u32, f32, bf16, f16 and i64; timing values not preserved]
Looks like we are getting a significant performance boost.
Implementation Details:
* MetalDevice::ones_impl now introduces a code path similar to the random implementation - falls back to the CpuStorage based implementation for DType::f64
* fill - following the pattern implemented for normal metal kernels
* candle-core test cases to test for bf16 and f16 types
* candle-metal-kernels for the constant_fill related code introduced

Please let me know if I'm missing something!
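The dtype dispatch described in the implementation details might look roughly like the following. It is a sketch only: the DType enum here is a local stand-in listing the dtypes named in this PR, and FillPath and ones_path are hypothetical names, not candle's API.

```rust
/// Local stand-in for the dtypes mentioned in this PR (not candle's `DType`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    U8,
    U32,
    I64,
    BF16,
    F16,
    F32,
    F64,
}

/// Which path produces the tensor of ones.
#[derive(Debug, PartialEq)]
enum FillPath {
    /// Fill the buffer directly on the GPU with a constant-fill kernel.
    MetalKernel,
    /// Build the buffer on the CPU and copy it to the GPU.
    CpuFallback,
}

fn ones_path(dtype: DType) -> FillPath {
    match dtype {
        // f64 takes the CPU fallback, as stated in the PR description.
        DType::F64 => FillPath::CpuFallback,
        DType::U8 | DType::U32 | DType::I64 | DType::BF16 | DType::F16 | DType::F32 => {
            FillPath::MetalKernel
        }
    }
}

fn main() {
    for dtype in [DType::U8, DType::F32, DType::F64] {
        println!("{dtype:?} -> {:?}", ones_path(dtype));
    }
}
```

Listing every dtype explicitly, instead of using a wildcard arm, means the compiler flags this match if a new dtype is added later.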