
[Web] Wav2vec2 slower on WebGPU than WASM #21618

Closed
@gianlourbano

Description

Describe the issue

I converted the Wav2Vec2 model from PyTorch to ONNX with the following script:

import torch
import torch.onnx
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()

audio_length = 16000
dummy_input = torch.ones([1, audio_length])

onnx_filename = "wav2vec2.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_filename,
    export_params=True,
    do_constant_folding=True,
    opset_version=20,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        # The second dimension (index 1) is variable.
        "input": {1: "audio_length"},
        "output": {1: "audio_length"},
    },
)

After splitting the audio into 10 s chunks, I run the model first with the wasm execution provider and then with webgpu, and the latter is much slower than bare wasm. This doesn't happen with onnxruntime in Python, or even with torch, where using the GPU gives significant improvements.
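For context, the chunked inference loop looks roughly like this (a minimal sketch, not the repo's exact code: the 10 s chunking is plain arithmetic at 16 kHz, and the commented part uses the standard onnxruntime-web session API; the file name and option values are assumptions):

```javascript
// Split a Float32Array of PCM samples into 10 s chunks at 16 kHz.
const SAMPLE_RATE = 16000;
const CHUNK_SECONDS = 10;
const CHUNK_SIZE = SAMPLE_RATE * CHUNK_SECONDS;

function splitIntoChunks(samples, chunkSize = CHUNK_SIZE) {
  const chunks = [];
  for (let start = 0; start < samples.length; start += chunkSize) {
    chunks.push(samples.subarray(start, Math.min(start + chunkSize, samples.length)));
  }
  return chunks;
}

// Browser-only usage with onnxruntime-web; switching between the two timed
// runs is just 'webgpu' vs 'wasm' in executionProviders:
//
//   const session = await ort.InferenceSession.create('wav2vec2.onnx', {
//     executionProviders: ['webgpu'],   // or ['wasm']
//   });
//   for (const chunk of splitIntoChunks(audio)) {
//     const input = new ort.Tensor('float32', chunk, [1, chunk.length]);
//     const t0 = performance.now();
//     const { output } = await session.run({ input });
//     console.log(`step: ${performance.now() - t0} ms`);
//   }
```

A ~21 s clip at 16 kHz therefore produces two full 160 000-sample chunks plus one short tail chunk, matching the three timed steps below.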

For example, on a ~21 s audio clip I get the following timings, where each step runs the InferenceSession on one 10 s chunk:

webgpu
step onnx 0: 6231 ms
step onnx 1: 5424 ms
step onnx 2: 882 ms
onnx: 17171 ms
total: 17238 ms

wasm
step onnx 0: 3381 ms
step onnx 1: 3222 ms
step onnx 2: 580 ms
onnx: 9365 ms
total: 9448 ms

From the WebGPU profiling I see that the Conv layers take up most of the time. I also tried binding the inputs to the GPU, but it didn't improve anything.
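For completeness, the input-binding attempt looked roughly like this (a sketch using the documented ort-web WebGPU options, `preferredOutputLocation` and `ort.Tensor.fromGpuBuffer`; `device` is a `GPUDevice` obtained elsewhere, and the exact buffer-usage flags are assumptions following the ort-web examples):

```javascript
// Bytes needed for a float32 tensor of the given dims.
function float32ByteLength(dims) {
  return 4 * dims.reduce((a, b) => a * b, 1);
}

// Browser-only sketch of GPU input binding with onnxruntime-web:
//
//   const session = await ort.InferenceSession.create('wav2vec2.onnx', {
//     executionProviders: ['webgpu'],
//     preferredOutputLocation: 'gpu-buffer', // keep outputs on the GPU
//   });
//   const dims = [1, chunk.length];
//   const buffer = device.createBuffer({
//     size: float32ByteLength(dims),
//     usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
//   });
//   device.queue.writeBuffer(buffer, 0, chunk);
//   const input = ort.Tensor.fromGpuBuffer(buffer, { dataType: 'float32', dims });
//   const results = await session.run({ input });
```

Binding avoids the CPU↔GPU copies around each run, so if it makes no difference, the cost is presumably inside the Conv kernels themselves rather than in data transfer.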

To reproduce

Here's a link to a sample repo; instructions are in the README.

Urgency

Urgent, as this project is related to my thesis

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.18.0

Execution Provider

'wasm'/'cpu' (WebAssembly CPU), 'webgpu' (WebGPU)


    Labels

    ep:WebGPU (ort-web webgpu provider), model:transformer (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.), platform:web (issues related to ONNX Runtime web; typically submitted using template)
