Skip to content

Commit

Permalink
support opset13 in embednorm (microsoft#6866)
Browse files: browse the repository at this point in the history
wangyems authored Mar 2, 2021
1 parent 0d0eb2c commit 9073f7a
Showing 10 changed files with 99 additions and 28 deletions.
16 changes: 8 additions & 8 deletions onnxruntime/core/optimizer/embed_layer_norm_fusion.cc
Original file line number Diff line number Diff line change
@@ -112,10 +112,10 @@ static bool MatchInputToConcatSubgraph(
const logging::Logger& logger,
const NodeIndex expected_gather_node_1_index) {
std::vector<graph_utils::EdgeEndToMatch> expand_parent_path1{
{0, index, "Concat", {4, 11}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain},
{0, index, "Concat", {4, 11, 13}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain},
};

std::vector<const Node::EdgeEnd*> edges;
@@ -145,9 +145,9 @@ static bool MatchInputToConcatSubgraph(
}

std::vector<graph_utils::EdgeEndToMatch> concat_parent_path{
{0, 1, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
{0, 1, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain}};

if (!graph_utils::FindPath(concat_node, true, concat_parent_path, edges, logger)) {
DEBUG_LOG("Failed to find path 2 of position shape.");
@@ -316,7 +316,7 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(

// Match Shape --> Expand path.
std::vector<const Node::EdgeEnd*> pg_edges_2;
if (!graph_utils::FindPath(expand_node, true, {{0, 1, "Shape", {1}, kOnnxDomain}}, pg_edges_2, logger)) {
if (!graph_utils::FindPath(expand_node, true, {{0, 1, "Shape", {1, 13}, kOnnxDomain}}, pg_edges_2, logger)) {
DEBUG_LOG("Failed to match Shape node. ");
return false;
}
111 changes: 91 additions & 20 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
@@ -2897,15 +2897,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat2) {
ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1);
}

TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx";
static void EmbedLayerNormFusionFormat3(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -2923,15 +2922,22 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) {
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
}

TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast.onnx";
// Runs the shared EmbedLayerNormFusionFormat3 check helper against the
// original format3 model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx";
  EmbedLayerNormFusionFormat3(model_uri, logger_.get());
}

// Same check as EmbedLayerNormFusionFormat3, pointed at the `_opset13`
// variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3_OpSet13) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3_opset13.onnx";
  EmbedLayerNormFusionFormat3(model_uri, logger_.get());
}

static void EmbedLayerNormFusionFormat3NoCast(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -2949,6 +2955,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast) {
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
}

// Runs the shared EmbedLayerNormFusionFormat3NoCast check helper against the
// original format3 "no cast" model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast.onnx";
  EmbedLayerNormFusionFormat3NoCast(model_uri, logger_.get());
}

// Same check as EmbedLayerNormFusionFormat3NoCast, pointed at the `_opset13`
// variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast_OpSet13) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast_opset13.onnx";
  EmbedLayerNormFusionFormat3NoCast(model_uri, logger_.get());
}

TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format4.onnx";
std::shared_ptr<Model> p_model;
@@ -2977,15 +2991,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) {
ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1);
}

TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx";
static void EmbedLayerNormFusionFormat5(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -3019,15 +3032,22 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) {
}
}

TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx";
// Runs the shared EmbedLayerNormFusionFormat5 check helper against the
// original format5 model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx";
  EmbedLayerNormFusionFormat5(model_uri, logger_.get());
}

// Same check as EmbedLayerNormFusionFormat5, pointed at the `_opset13`
// variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5_OpSet13) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format5_opset13.onnx";
  EmbedLayerNormFusionFormat5(model_uri, logger_.get());
}

static void EmbedLayerNormFusionFormat6(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -3048,6 +3068,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
}

// Runs the shared EmbedLayerNormFusionFormat6 check helper against the
// original format6 model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx";
  EmbedLayerNormFusionFormat6(model_uri, logger_.get());
}

// Same check as EmbedLayerNormFusionFormat6, pointed at the `_opset13`
// variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6_OpSet13) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format6_opset13.onnx";
  EmbedLayerNormFusionFormat6(model_uri, logger_.get());
}

static void TestEmbedLayerNormFusionDistilBert(const std::basic_string<ORTCHAR_T>& model_uri,
std::map<std::string, int>& op_to_count,
logging::Logger* logger) {
@@ -3076,6 +3104,18 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}

// Opset-13 counterpart of the Format7 test: hands the `_opset13` format7 model
// to TestEmbedLayerNormFusionDistilBert and then checks the per-op counts.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7_OpSet13) {
// Passed to the helper by non-const reference; the expectations below read
// whatever counts the helper leaves in it.
std::map<std::string, int> op_to_count;
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7_opset13.onnx", op_to_count, logger_.get());
// Exactly one fused EmbedLayerNormalization node and one Attention node are expected.
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
EXPECT_EQ(op_to_count["Cast"], 2);
// No Shape / Gather / Unsqueeze nodes are expected in the final counts.
EXPECT_EQ(op_to_count["Shape"], 0);
EXPECT_EQ(op_to_count["Gather"], 0);
EXPECT_EQ(op_to_count["Unsqueeze"], 0);
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}

TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) {
std::map<std::string, int> op_to_count;
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get());
@@ -3088,6 +3128,18 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) {
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}

// Opset-13 counterpart of the Format8 test: hands the `_opset13` format8 model
// to TestEmbedLayerNormFusionDistilBert and then checks the per-op counts.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8_OpSet13) {
// Passed to the helper by non-const reference; the expectations below read
// whatever counts the helper leaves in it.
std::map<std::string, int> op_to_count;
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8_opset13.onnx", op_to_count, logger_.get());
// Exactly one fused EmbedLayerNormalization node and one Attention node are expected.
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
EXPECT_EQ(op_to_count["Cast"], 2);
// No Shape / Gather / Unsqueeze nodes are expected in the final counts.
EXPECT_EQ(op_to_count["Shape"], 0);
EXPECT_EQ(op_to_count["Gather"], 0);
EXPECT_EQ(op_to_count["Unsqueeze"], 0);
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}

TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) {
std::map<std::string, int> op_to_count;
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get());
@@ -3100,15 +3152,26 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) {
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}

TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx";
// Opset-13 counterpart of the Format9 test: hands the `_opset13` format9 model
// to TestEmbedLayerNormFusionDistilBert and then checks the per-op counts.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9_OpSet13) {
// Passed to the helper by non-const reference; the expectations below read
// whatever counts the helper leaves in it.
std::map<std::string, int> op_to_count;
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9_opset13.onnx", op_to_count, logger_.get());
// Exactly one fused EmbedLayerNormalization node and one Attention node are expected.
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1);
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1);
EXPECT_EQ(op_to_count["Cast"], 2);
// This model is expected to retain one Shape, two Gather and two Unsqueeze
// nodes (the Format7/Format8 opset-13 tests expect zero of each).
EXPECT_EQ(op_to_count["Shape"], 1);
EXPECT_EQ(op_to_count["Gather"], 2);
EXPECT_EQ(op_to_count["Unsqueeze"], 2);
EXPECT_EQ(op_to_count["ReduceSum"], 1);
}

static void EmbedLayerNormFusionFormatMultiple(const std::basic_string<ORTCHAR_T>& file_path, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger);
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -3126,6 +3189,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 2);
}

// Runs the shared EmbedLayerNormFusionFormatMultiple check helper against the
// original "multiple" model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx";
  EmbedLayerNormFusionFormatMultiple(model_uri, logger_.get());
}

// Same check as EmbedLayerNormFusionMultiple, pointed at the `_opset13`
// variant of the model file.
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple_OpSet13) {
  auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple_opset13.onnx";
  EmbedLayerNormFusionFormatMultiple(model_uri, logger_.get());
}

TEST_F(GraphTransformationTests, DynamicQuantizeMatMulTest) {
auto model_uri = MODEL_FOLDER "fusion/dynamic_quantize_matmul.onnx";
std::shared_ptr<Model> p_model;
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 9073f7a

Please sign in to comment.