[WebGPU] Implement tir.dp4a with WGSL built-in function `dot4I8Packed` (#16976)

* [WebGPU] Support `__dp4a(int8x4, int8x4)` as a pure extern method

This patch adds support for `__dp4a(int8x4, int8x4)` as a pure
extern method of the WebGPU target. In the generated WGSL shader,
`int8x4` is translated into `u32`, and `__dp4a(int8x4, int8x4)`
is translated into the WGSL built-in function
`dot4I8Packed(u32, u32)`.
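
To make the translation concrete, the following sketch models in plain Python what `dot4I8Packed(u32, u32)` computes: each `u32` is read as four signed 8-bit lanes, and the result is the signed dot product of those lanes. The helper names `pack_i8x4`, `unpack_i8x4`, and `dot4_i8_packed` are hypothetical and exist only for illustration; they are not part of TVM or WGSL. The little-endian lane order is an assumption, and it does not affect the result as long as both operands use the same order.

```python
import struct


def pack_i8x4(lanes):
    """Pack four signed 8-bit integers into one unsigned 32-bit word.

    Assumption: lane 0 occupies the lowest byte (little-endian order).
    """
    return struct.unpack("<I", struct.pack("<4b", *lanes))[0]


def unpack_i8x4(word):
    """Split an unsigned 32-bit word back into four signed 8-bit lanes."""
    return struct.unpack("<4b", struct.pack("<I", word & 0xFFFFFFFF))


def dot4_i8_packed(a, b):
    """Reference model of the signed packed dot product on two 32-bit words."""
    return sum(x * y for x, y in zip(unpack_i8x4(a), unpack_i8x4(b)))


a = pack_i8x4([1, -2, 3, -4])
b = pack_i8x4([5, 6, -7, 8])
result = dot4_i8_packed(a, b)  # 1*5 + (-2)*6 + 3*(-7) + (-4)*8
```

The largest possible magnitude is 4 × 128 × 128 = 65536, so the sum always fits in a signed 32-bit result.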

Here is an example of using `__dp4a` with the WebGPU target:

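```
import tvm
from tvm import te

# Assumption: the original example leaves `tgt` undefined; a WebGPU target
# with a wasm host is one plausible choice.
tgt = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm")

n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]), name="C")
s = te.create_schedule(C.op)
# Split the single axis into blocks of 64 threads and bind both parts to GPU axes.
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```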
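
The `split` and `bind` calls in the example decide which GPU thread handles which element. As a rough illustration, the sketch below enumerates that mapping in plain Python, assuming the usual convention that element `i` is handled by block `i // 64` and thread `i % 64`, with a bounds guard for the final partial block. The function name `thread_indices` is hypothetical and is not a TVM API.

```python
def thread_indices(n, factor=64):
    """Yield (block, thread, element) triples for a 1-D axis split by `factor`.

    Assumption: element index = block * factor + thread, and threads whose
    index falls past the end of the axis do no work.
    """
    num_blocks = (n + factor - 1) // factor  # ceiling division
    for bx in range(num_blocks):
        for tx in range(factor):
            i = bx * factor + tx
            if i < n:  # guard for the last, partially filled block
                yield bx, tx, i


mapping = list(thread_indices(130))
```

With `n = 130`, this model uses three blocks, and only the first two threads of the last block have an element to process.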

Issue: #16627

* Add validation

* Add `dot4I8Packed` to WebGPU lower intrinsic

* Implement builtin `dp4a` with `dot4I8Packed`

* Small fix

* Add missing comment
Jiawei-Shao authored Jul 4, 2024
1 parent 0df4103 commit 3e08e70
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/target/source/codegen_webgpu.cc
@@ -410,6 +410,14 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN
       this->EndScope(else_scope);
     }
     os << result;
+  } else if (op->op.same_as(builtin::dp4a())) {
+    // generate `dot4I8Packed(vec1, vec2) + acc` for the builtin `dp4a`
+    os << "dot4I8Packed(";
+    this->PrintExpr(op->args[0], os);
+    os << ", ";
+    this->PrintExpr(op->args[1], os);
+    os << ") + ";
+    this->PrintExpr(op->args[2], os);
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
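
The added branch assembles a WGSL expression from the three arguments of the builtin `dp4a`: the two packed vectors and an accumulator. A minimal Python sketch of the same string assembly, assuming the arguments have already been printed as WGSL expression strings, might look like this. The function name `emit_dp4a` and the sample argument strings are hypothetical.

```python
def emit_dp4a(vec1, vec2, acc):
    """Build the WGSL text for dp4a(vec1, vec2, acc).

    Mirrors the order of writes in the diff: the call to dot4I8Packed on the
    two packed operands, then " + " and the accumulator expression.
    """
    return "dot4I8Packed(" + vec1 + ", " + vec2 + ") + " + acc


wgsl_expr = emit_dp4a("A[i]", "B[i]", "acc")
```

As in the diff, the sketch adds no outer parentheses around the emitted expression.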