Failed to compute freq_cis when initializing RotaryEncodingConfig
I was using stas/tiny-random-llama-2 for testing, and I switched llama2-burn from its original tch backend to the wgpu backend. Initializing RotaryEncodingConfig then fails, because the wgpu backend does not allow repeat on a dimension whose size is larger than 1.
Everything works correctly with the tch backend.
To Reproduce
```rust
use burn::tensor::Tensor;
use burn_wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};

type Backend = Wgpu<AutoGraphicsApi, f32, i32>;

fn main() {
    let device = WgpuDevice::Cpu;
    // A [4, 2, 2] input tensor.
    let input: Tensor<Backend, 3> = Tensor::from_floats(
        [
            [[1.0, 2.0], [3.0, 4.0]],
            [[5.0, 6.0], [7.0, 8.0]],
            [[9.0, 10.0], [11.0, 12.0]],
            [[13.0, 14.0], [15.0, 16.0]],
        ],
        &device,
    );
    // Repeat dimension 2 (size 2) twice.
    let repeated = input.repeat(2, 2);
    // panic: burn-jit-0.13.0/src/kernel/index/repeat.rs:112:9: Can only repeat dimension with dim=1
}
```
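For reference, here is a plain-Rust sketch (no Burn dependency) of what input.repeat(2, 2) would be expected to return for the input above, assuming repeat tiles the dimension the way PyTorch's Tensor.repeat (and therefore the tch backend) does. Dimension 2 has size 2, so each innermost row is concatenated with itself, giving shape [4, 2, 4]. The function repeat_last_dim is hypothetical and only illustrates the expected values.

```rust
/// Hypothetical reference: tile the last dimension of a [4][2][2] tensor `times`
/// times, mirroring what `repeat(2, times)` is expected to produce.
fn repeat_last_dim(input: &[[[f32; 2]; 2]; 4], times: usize) -> Vec<Vec<Vec<f32>>> {
    input
        .iter()
        .map(|matrix| {
            matrix
                .iter()
                // Concatenate each row with itself `times` times: [a, b] -> [a, b, a, b, ...].
                .map(|row| {
                    row.iter()
                        .copied()
                        .cycle()
                        .take(row.len() * times)
                        .collect::<Vec<f32>>()
                })
                .collect::<Vec<Vec<f32>>>()
        })
        .collect()
}

fn main() {
    let input = [
        [[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [7.0, 8.0]],
        [[9.0, 10.0], [11.0, 12.0]],
        [[13.0, 14.0], [15.0, 16.0]],
    ];
    let repeated = repeat_last_dim(&input, 2);
    // Shape [4, 2, 4]; prints [1.0, 2.0, 1.0, 2.0]
    println!("{:?}", repeated[0][0]);
}
```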
Hi @hlhr202, thanks for reporting the issue; we will try to fix it before the next release. In the meantime, you can use multiple calls to slice_assign as a fallback for repeat.
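The slice_assign fallback works because tiling along a dimension is just copying the whole input into `times` consecutive slices of a larger output. The sketch below shows that idea on a plain row-major f32 buffer rather than through Burn's API; repeat_via_slices is a hypothetical helper, and it assumes tiling semantics (like PyTorch's Tensor.repeat). With Burn, the equivalent would be one slice_assign per repetition into a zero-initialized output tensor.

```rust
/// Hypothetical fallback: repeat `data` (row-major, shape `shape`) `times` times along
/// `dim` by copying the input block into each successive slice of the output,
/// similar in spirit to issuing one slice_assign per repetition.
fn repeat_via_slices(data: &[f32], shape: &[usize], dim: usize, times: usize) -> (Vec<f32>, Vec<usize>) {
    assert!(dim < shape.len());
    assert_eq!(data.len(), shape.iter().product::<usize>());
    // Number of independent blocks before `dim`, and contiguous elements per block.
    let outer: usize = shape[..dim].iter().product();
    let block: usize = shape[dim..].iter().product();

    let mut out = vec![0.0f32; data.len() * times];
    for o in 0..outer {
        let src = &data[o * block..(o + 1) * block];
        for t in 0..times {
            // Destination slice for repetition `t` within outer block `o`.
            let start = (o * times + t) * block;
            out[start..start + block].copy_from_slice(src);
        }
    }
    let mut out_shape = shape.to_vec();
    out_shape[dim] *= times;
    (out, out_shape)
}

fn main() {
    // The [4, 2, 2] input from the reproduction, flattened row-major.
    let data: Vec<f32> = (1..=16).map(|x| x as f32).collect();
    let (out, shape) = repeat_via_slices(&data, &[4, 2, 2], 2, 2);
    println!("{:?}", shape); // [4, 2, 4]
    println!("{:?}", &out[..8]); // [1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]
}
```

Note that the size of the input along `dim` never matters here: the copied block always covers the full extent of that dimension.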
It's not a wgpu-specific problem: the default implementation, used by ndarray and candle, does not support repeating dimensions whose size is larger than 1. When I wrote a kernel for wgpu I mimicked that behaviour without asking questions, but now that I think about it, it's just a silly limitation that is easy to overcome. I'll update both the wgpu and default implementations to allow it.
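To see why the limitation is easy to lift, here is a hypothetical per-element formulation in plain Rust, written the way a compute kernel might process one output index per invocation. It is not Burn's actual kernel; it assumes row-major layout and tiling semantics. The only difference from a plain copy is that the coordinate along `dim` is wrapped with a modulo, which works for any input size along that dimension.

```rust
/// Hypothetical per-element repeat (tiling along `dim`), in the style of a kernel
/// where each loop iteration corresponds to one output element.
fn repeat_index_map(data: &[f32], shape: &[usize], dim: usize, times: usize) -> (Vec<f32>, Vec<usize>) {
    let mut out_shape = shape.to_vec();
    out_shape[dim] *= times;
    let rank = shape.len();

    // Row-major strides of the input tensor.
    let mut in_strides = vec![1usize; rank];
    for d in (0..rank.saturating_sub(1)).rev() {
        in_strides[d] = in_strides[d + 1] * shape[d + 1];
    }

    let out_len: usize = out_shape.iter().product();
    let mut out = Vec::with_capacity(out_len);
    for linear in 0..out_len {
        // Decompose the output linear index into coordinates (last dim fastest)
        // and accumulate the matching input offset.
        let mut rem = linear;
        let mut in_offset = 0;
        for d in (0..rank).rev() {
            let coord = rem % out_shape[d];
            rem /= out_shape[d];
            // Wrap the coordinate along `dim`; valid for any shape[dim], not only 1.
            let src_coord = if d == dim { coord % shape[d] } else { coord };
            in_offset += src_coord * in_strides[d];
        }
        out.push(data[in_offset]);
    }
    (out, out_shape)
}

fn main() {
    let data: Vec<f32> = (1..=16).map(|x| x as f32).collect();
    let (out, shape) = repeat_index_map(&data, &[4, 2, 2], 2, 2);
    println!("{:?}", shape); // [4, 2, 4]
    println!("{:?}", &out[..8]); // [1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]
}
```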
Describe the bug
I am playing with Gadersd's llama2-burn.
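Expected behavior
repeat should work on dimensions whose size is larger than 1 with the wgpu backend, as it does with the tch backend.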
Screenshots
None; the code provided is enough to reproduce the issue.
Additional context
burn/crates/burn-jit/src/kernel/index/repeat.rs, line 112 (commit 1433284)
Gadersd/llama2-burn#3