[WebNN] Fix bug in SkipSimplifiedLayerNormalization (#23236)
The skip and bias (if it exists) should be added to the input first.
Honry authored Jan 7, 2025
1 parent 655b3ef commit 519fae0
Showing 1 changed file with 19 additions and 17 deletions.
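
For context, this builder decomposes SimplifiedLayerNormalization into the WebNN op chain Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul sketched in the comment block of the second hunk below, which informally computes

    Y = scale * X / sqrt(ReduceMean(X^2, axes) + epsilon)    (with a trailing add of bias when it exists)

For SkipSimplifiedLayerNormalization, that chain must consume input_skip_bias_sum = X + skip + bias (if it exists) rather than the raw X, and the same sum backs the node's optional fourth output (output_defs[3]). Before this commit the sum was only assembled afterwards for that output, so the normalization chain ran on the raw input; the fix forms the sum first and feeds it through the chain.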
@@ -76,8 +76,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
options.set("epsilon", epsilon);

emscripten::val output = emscripten::val::undefined();
- // SkipSimplifiedLayerNormalization's output: input_skip_bias_sum.
- emscripten::val input_skip_bias_sum = emscripten::val::undefined();
if (op_type == "BatchNormalization") {
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
@@ -107,14 +105,31 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
| | | | | |
Y:2 axis B:epsilon A:X A:scale B:bias
- If it is SkipSimplifiedLayerNormalization and its output input_skip_bias_sum exists,
+ If it is SkipSimplifiedLayerNormalization, X should be input_skip_bias_sum:
input_skip_bias_sum = X + skip + bias (if it exists)
*/

int32_t input_type;
ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type");
emscripten::val common_options = emscripten::val::object();

+ // If it is SkipSimplifiedLayerNormalization, add the skip and bias (if it exists) to the input.
+ if (op_type == "SkipSimplifiedLayerNormalization") {
+ emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name());
+ common_options.set("label", node.Name() + "_add_skip");
+ input = model_builder.GetBuilder().call<emscripten::val>("add", input, skip, common_options);
+ if (!bias.isUndefined()) {
+ common_options.set("label", node.Name() + "_add_skip_bias");
+ input = model_builder.GetBuilder().call<emscripten::val>("add", input, bias, common_options);
+ }
+
+ // Add SkipSimplifiedLayerNormalization's output input_skip_bias_sum if it exists.
+ // Now input equals to input_skip_bias_sum.
+ if (TensorExists(output_defs, 3)) {
+ model_builder.AddOperand(output_defs[3]->Name(), input);
+ }
+ }

// Pow
emscripten::val pow_constant = model_builder.CreateOrGetConstant<float>(input_type, 2);
common_options.set("label", node.Name() + "_pow");
@@ -146,24 +161,11 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
common_options.set("label", node.Name() + "_mul");
output = model_builder.GetBuilder().call<emscripten::val>("mul", scale, div, common_options);

- // Add (if bias exits)
+ // Add (if bias exists)
if (!bias.isUndefined()) {
common_options.set("label", node.Name() + "_add_bias");
output = model_builder.GetBuilder().call<emscripten::val>("add", output, bias, common_options);
}

- // SkipSimplifiedLayerNormalization's output input_skip_bias_sum is the sum of input, skip, and bias.
- if (op_type == "SkipSimplifiedLayerNormalization" && TensorExists(output_defs, 3)) {
- emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name());
- common_options.set("label", node.Name() + "_add_skip");
- input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>("add", input, skip, common_options);
- if (!bias.isUndefined()) {
- common_options.set("label", node.Name() + "_add_skip_bias");
- input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>(
- "add", input_skip_bias_sum, bias, common_options);
- }
- model_builder.AddOperand(output_defs[3]->Name(), std::move(input_skip_bias_sum));
- }
}
} else if (op_type == "InstanceNormalization") {
// WebNN spec only supports 4D input for instanceNormalization.
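
For anyone who wants to sanity-check the corrected order on plain numbers, here is a minimal, self-contained C++ sketch. It is not ONNX Runtime or WebNN code: the name SkipSimplifiedLayerNormRef is hypothetical, normalization is assumed to run over the whole vector (a single axis), and the builder's trailing bias add after the Mul step is deliberately not modeled.

#include <cmath>
#include <cstddef>
#include <vector>

struct SkipSimplifiedLayerNormResult {
  std::vector<float> output;               // scale * sum / sqrt(mean(sum^2) + epsilon)
  std::vector<float> input_skip_bias_sum;  // x + skip (+ bias)
};

// Hypothetical reference sketch: forms input_skip_bias_sum first, as the fix
// requires, then applies the simplified (RMS-style) normalization to that sum
// and also returns the sum, mirroring the node's optional fourth output.
SkipSimplifiedLayerNormResult SkipSimplifiedLayerNormRef(const std::vector<float>& x,
                                                         const std::vector<float>& skip,
                                                         const std::vector<float>& bias,  // empty when absent
                                                         const std::vector<float>& scale,
                                                         float epsilon) {
  const std::size_t n = x.size();
  SkipSimplifiedLayerNormResult r;
  r.input_skip_bias_sum.resize(n);
  r.output.resize(n);

  // Step 1 (the part this commit reorders): build the sum before normalizing.
  for (std::size_t i = 0; i < n; ++i) {
    r.input_skip_bias_sum[i] = x[i] + skip[i] + (bias.empty() ? 0.0f : bias[i]);
  }

  // Step 2: Pow -> ReduceMean -> Add(epsilon) -> Sqrt, i.e. the RMS of the sum.
  float mean_sq = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    mean_sq += r.input_skip_bias_sum[i] * r.input_skip_bias_sum[i];
  }
  if (n > 0) {
    mean_sq /= static_cast<float>(n);
  }
  const float rms = std::sqrt(mean_sq + epsilon);

  // Step 3: Div by the RMS, then Mul by scale.
  for (std::size_t i = 0; i < n; ++i) {
    r.output[i] = scale[i] * (r.input_skip_bias_sum[i] / rms);
  }
  return r;
}

Note how the new code reuses the summed operand both as the input to the Pow chain and as the value registered for output_defs[3], instead of building a second add chain the way the removed block did.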
