
Failed to run llama2-burn on webgpu #1670

Closed
hlhr202 opened this issue Apr 21, 2024 · 2 comments · Fixed by #1713

Labels: bug (Something isn't working)

Comments


hlhr202 commented Apr 21, 2024

Describe the bug

I am playing with Gadersd's llama2-burn.

I failed to compute freq_cis when initializing RotaryEncodingConfig. I was using stas/tiny-random-llama-2 for testing, and I adapted llama2-burn from its original tch backend to the wgpu backend. RotaryEncodingConfig cannot be initialized because the wgpu backend does not allow repeating along a dimension whose size is larger than 1.

Everything works well with the tch backend.

To Reproduce

use burn::tensor::Tensor;
use burn_wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::Cpu;

let input: Tensor<Backend, 3> = Tensor::from_floats(
    [
        [[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [7.0, 8.0]],
        [[9.0, 10.0], [11.0, 12.0]],
        [[13.0, 14.0], [15.0, 16.0]],
    ],
    &device,
);

// panics at burn-jit-0.13.0/src/kernel/index/repeat.rs:112:9:
// "Can only repeat dimension with dim=1"
let repeated = input.repeat(2, 2);

Expected behavior

use burn::tensor::Tensor;
use burn_tch::{LibTorch, LibTorchDevice};
type Backend = LibTorch;
let device = LibTorchDevice::Cpu;

let input: Tensor<Backend, 3> = Tensor::from_floats(
    [
        [[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [7.0, 8.0]],
        [[9.0, 10.0], [11.0, 12.0]],
        [[13.0, 14.0], [15.0, 16.0]],
    ],
    &device,
);

let repeated = input.repeat(2, 2);
/* repeated should be 
Tensor {
  data:
[[[1.0, 2.0, 1.0, 2.0],
  [3.0, 4.0, 3.0, 4.0]],
 [[5.0, 6.0, 5.0, 6.0],
  [7.0, 8.0, 7.0, 8.0]],
 [[9.0, 10.0, 9.0, 10.0],
  [11.0, 12.0, 11.0, 12.0]],
 [[13.0, 14.0, 13.0, 14.0],
  [15.0, 16.0, 15.0, 16.0]]],
  shape:  [4, 2, 4],
  device:  Cpu,
  backend:  "tch",
  kind:  "Float",
  dtype:  "f32",
}
*/

Screenshots
Not applicable; the code provided above is sufficient to reproduce.

Desktop (please complete the following information):

  • OS: Ubuntu 22.04.4 LTS (WSL2)
  • Browser: not applicable
  • Version: burn 0.13.0

Additional context

panic!("Can only repeat dimension with dim=1");

Gadersd/llama2-burn#3

@nathanielsimard nathanielsimard self-assigned this Apr 22, 2024
@nathanielsimard nathanielsimard added the bug Something isn't working label Apr 22, 2024
@nathanielsimard nathanielsimard moved this to In Progress in Burn 🔥 Apr 22, 2024
@nathanielsimard nathanielsimard moved this from In Progress to Todo in Burn 🔥 Apr 22, 2024
@nathanielsimard (Member)

Hi @hlhr202, thanks for reporting the issue; we will try to fix it before the next release. In the meantime, you can use multiple calls to slice_assign as a fallback for repeat, as sketched below.
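
A minimal sketch of that fallback, assuming burn 0.13's Tensor::zeros, Tensor::dims, and Tensor::slice_assign APIs; repeat_last_dim is a hypothetical helper name, not part of burn:

use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper: emulate `input.repeat(2, times)` for a rank-3 tensor by
// writing `times` copies of `input` into a zero-initialized output via slice_assign.
fn repeat_last_dim<B: Backend>(input: Tensor<B, 3>, times: usize) -> Tensor<B, 3> {
    let [d0, d1, d2] = input.dims();
    let mut out = Tensor::zeros([d0, d1, d2 * times], &input.device());
    for i in 0..times {
        // Write the input into the i-th slot along the last dimension.
        out = out.slice_assign([0..d0, 0..d1, i * d2..(i + 1) * d2], input.clone());
    }
    out
}

On the tensor from the repro, repeat_last_dim(input, 2) should match what input.repeat(2, 2) produces on the tch backend.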

@louisfd (Member) commented Apr 30, 2024

It's not a wgpu-specific problem: the default implementation, used by ndarray and candle, does not support repeating along dimensions whose size is larger than 1. When I wrote a kernel for wgpu I mimicked that behaviour without questioning it, but now that I think about it, it's just a silly limitation that is easy to overcome. I'll update both the wgpu and default implementations to allow it.
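
Until that fix lands, a backend-agnostic way to express the same operation is to concatenate copies of the tensor along the target dimension; a minimal sketch, assuming burn 0.13's Tensor::cat API (repeat_via_cat is a hypothetical helper name):

use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper: repeat along any dimension, of any size, by concatenating
// `times` copies of the tensor along that dimension.
fn repeat_via_cat<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    dim: usize,
    times: usize,
) -> Tensor<B, D> {
    let copies: Vec<_> = (0..times).map(|_| input.clone()).collect();
    Tensor::cat(copies, dim)
}

On the tensor from the repro, repeat_via_cat(input, 2, 2) should yield the same [4, 2, 4] result as input.repeat(2, 2) does on the tch backend.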

@louisfd louisfd moved this from Todo to In Review in Burn 🔥 Apr 30, 2024
@github-project-automation github-project-automation bot moved this from In Review to Done in Burn 🔥 May 1, 2024