Skip to content

Commit

Permalink
[js/webgpu] Add HardSigmoid support (microsoft#19215)
Browse files Browse the repository at this point in the history
### Description
This op is required in mobilenetv3-small-100. With this PR,
mobilenetv3-small-100 model becomes less than 10 ms from over 100 ms on
ADL.
  • Loading branch information
qjia7 authored Jan 22, 2024
1 parent 4d88884 commit f8d5281
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 3 deletions.
1 change: 1 addition & 0 deletions web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
Expand Down
1 change: 1 addition & 0 deletions web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
Expand Down
20 changes: 20 additions & 0 deletions web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
};

export interface HardSigmoidAttributes extends AttributeWithCacheKey {
readonly alpha: number;
readonly beta: number;
}

export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
createAttributeWithCacheKey(attributes as {
alpha: number;
beta: number;
});

export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'HardSigmoid',
a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
attributes.beta})))`,
undefined, attributes.cacheKey));
};

export const sin = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin'));
};
Expand Down
6 changes: 3 additions & 3 deletions web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -597,9 +597,9 @@
// // "test_hardmax_example",
// // "test_hardmax_negative_axis",
// // "test_hardmax_one_hot",
// // "test_hardsigmoid_default",
// // "test_hardsigmoid_example",
// // "test_hardsigmoid",
"test_hardsigmoid_default",
"test_hardsigmoid_example",
"test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
"test_if",
Expand Down

0 comments on commit f8d5281

Please sign in to comment.