Skip to content

Commit

Permalink
[QDQ] Add shared qdq selectors (microsoft#10178)
Browse files Browse the repository at this point in the history
* wip

* wip

* wip

* wip

* wip

* save

* minor changes

* update test graph name

* address pr comments

* update

* address pr comments

* address pr comments

* fix warning

* minor include fix

* update to nodegroupselectors

* delete unnecessary includes

Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
  • Loading branch information
YUNQIUGUO and rachguo authored Jan 12, 2022
1 parent 79d2a0d commit a099bd4
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ else()
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "utils.h"

#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include <core/graph/graph_viewer.h>
#include <core/providers/common.h>

#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"

namespace onnxruntime {
namespace QDQ {

// Wraps the given op/version map and selector in a heap-allocated
// OpVersionsAndSelector and stores it in selectors_set_.
// Ownership of `selector_in` is transferred to the new entry.
void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in,
                                 std::unique_ptr<NodeGroupSelector> selector_in) {
  // The insert result (iterator, inserted flag) is intentionally discarded.
  ORT_IGNORE_RETURN_VALUE(selectors_set_.insert(
      std::make_unique<OpVersionsAndSelector>(ops_and_versions_in, std::move(selector_in))));
}

/* File-local helper functions, each returning the OpVersionsMap for one group of operators */
// Returns the op type -> supported opset versions map for the "misc" group.
// An empty version list places no restriction on the op version
// (GetQDQSelections only checks versions when the list is non-empty).
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap misc_ops{
      {"Gather", {}},
      {"Reshape", {}},
      {"Transpose", {}},
      {"MaxPool", {12}},
      {"Resize", {}},
  };
  return misc_ops;
}

// Returns the op type -> supported opset versions map for the unary group.
// Empty version lists mean no version restriction.
static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap unary_ops{
      {"AveragePool", {}},
      {"LeakyRelu", {}},
  };
  return unary_ops;
}
// Returns the op type -> supported opset versions map for the binary group.
// Empty version lists mean no version restriction.
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap binary_ops{
      {"Add", {}},
      {"Mul", {}},
  };
  return binary_ops;
}
// Returns the op type -> supported opset versions map for the variadic group.
// Empty version lists mean no version restriction.
static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap variadic_ops{
      {"Concat", {}},
  };
  return variadic_ops;
}
// Returns the op type -> supported opset versions map for Conv.
// Empty version lists mean no version restriction.
static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap conv_ops{
      {"Conv", {}},
  };
  return conv_ops;
}
// Returns the op type -> supported opset versions map for MatMul.
// Empty version lists mean no version restriction.
static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap matmul_ops{
      {"MatMul", {}},
  };
  return matmul_ops;
}

/* Selector rules registration related */
// Registers a DropQDQNodeGroupSelector for the miscellaneous ops
// listed by GetMiscOpVersionsMap().
void RegisterMiscSelectors(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetMiscOpVersionsMap(),
                                 std::make_unique<DropQDQNodeGroupSelector>());
}

// Registers a UnaryNodeGroupSelector for the unary ops
// listed by GetUnaryOpVersionsMap().
void RegisterUnarySelectors(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetUnaryOpVersionsMap(),
                                 std::make_unique<UnaryNodeGroupSelector>());
}

// Registers a BinaryNodeGroupSelector for the binary ops
// listed by GetBinaryOpVersionsMap().
void RegisterBinarySelectors(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetBinaryOpVersionsMap(),
                                 std::make_unique<BinaryNodeGroupSelector>());
}

// Registers a VariadicNodeGroupSelector for the variadic ops
// listed by GetVariadicOpVersionsMap().
void RegisterVariadicSelectors(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetVariadicOpVersionsMap(),
                                 std::make_unique<VariadicNodeGroupSelector>());
}

// Registers a ConvNodeGroupSelector for the ops
// listed by GetConvOpVersionsMap().
void RegisterConvSelector(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetConvOpVersionsMap(),
                                 std::make_unique<ConvNodeGroupSelector>());
}

// Registers a MatMulNodeGroupSelector for the ops
// listed by GetMatMulOpVersionsMap().
void RegisterMatMulSelector(Selectors& qdq_selectors) {
  qdq_selectors.RegisterSelector(GetMatMulOpVersionsMap(),
                                 std::make_unique<MatMulNodeGroupSelector>());
}

void SelectorManager::CreateSelectors() {
RegisterMiscSelectors(qdq_selectors_);
RegisterUnarySelectors(qdq_selectors_);
RegisterBinarySelectors(qdq_selectors_);
RegisterVariadicSelectors(qdq_selectors_);
RegisterConvSelector(qdq_selectors_);
RegisterMatMulSelector(qdq_selectors_);
}

void SelectorManager::InitializeSelectorsMap() {
for (const auto& entry : qdq_selectors_.SelectorsSet()) {
for (const auto& op_info : entry->op_versions_map) {
bool inserted = op_type_to_selectors_map_.insert({op_info.first, &*entry}).second;
ORT_ENFORCE(inserted, "Multiple entries for operator is not supported. OpType=", op_info.first);
}
}
}

void SelectorManager::Initialize() {
CreateSelectors();
InitializeSelectorsMap();
}

// Walks the graph in topological order and collects a QDQ NodeGroup for every
// ONNX-domain node whose op type has a registered selector, whose op version
// is accepted by that selector's version list (an empty list accepts any
// version), and for which the selector reports a selection.
//
// @param graph_viewer  Graph to inspect.
// @return The selected node groups, in the order their target nodes were visited.
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const {
  std::vector<NodeGroup> qdq_selections;
  for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
    const auto* node = graph_viewer.GetNode(index);
    // Defensive: skip an index the viewer cannot resolve rather than dereferencing null.
    if (node == nullptr || node->Domain() != kOnnxDomain) {
      continue;
    }

    const auto op_rule = op_type_to_selectors_map_.find(node->OpType());
    if (op_rule == op_type_to_selectors_map_.cend()) {
      continue;
    }

    const auto& op_versions_and_selector = *op_rule->second;

    // check the supported versions if specified
    // InitializeSelectorsMap() only maps an op type to an entry whose
    // op_versions_map contains that op type, so this lookup always succeeds.
    const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second;
    if (!versions.empty()) {
      if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) {
        LOGS_DEFAULT(VERBOSE) << "Op version is not supported for " << node->OpType();
        continue;
      }
    }

    // Kept non-const so the selected group can be moved into the result vector.
    auto qdq_node_group_selection = op_versions_and_selector.selector->GetQDQSelection(graph_viewer, *node);
    if (qdq_node_group_selection.has_value()) {
      qdq_selections.push_back(std::move(*qdq_node_group_selection));
    }
  }

  return qdq_selections;
}

} // namespace QDQ
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include "core/graph/basic_types.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/selectors_actions/helpers.h"

namespace onnxruntime {

class GraphViewer;
class Node;

namespace QDQ {

// Pairs a NodeGroupSelector with the ops it applies to and, per op,
// the opset versions it supports.
struct OpVersionsAndSelector {
  // Key: op type. Value: supported opset versions for that op.
  using OpVersionsMap = std::unordered_map<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>;

  // Copies the op/version map and takes ownership of the selector.
  OpVersionsAndSelector(const OpVersionsMap& ops_and_versions_in,
                        std::unique_ptr<NodeGroupSelector> selector_in)
      : op_versions_map(ops_and_versions_in),
        selector(std::move(selector_in)) {
  }

  OpVersionsMap op_versions_map;
  std::unique_ptr<NodeGroupSelector> selector;

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpVersionsAndSelector);
};

// Owns a collection of registered node group selectors, each stored together
// with the ops/versions it applies to.
class Selectors {
 public:
  Selectors() = default;

  // Registers `selector_in` for the ops (and versions) in `ops_and_versions_in`.
  // Ownership of the selector is transferred to this object.
  void RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in,
                        std::unique_ptr<NodeGroupSelector> selector_in);

  // Read-only access to all registered entries.
  const std::unordered_set<std::unique_ptr<OpVersionsAndSelector>>& SelectorsSet() const {
    return selectors_set_;
  }

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Selectors);

 private:
  // Registered entries; each one is uniquely owned by this set.
  std::unordered_set<std::unique_ptr<OpVersionsAndSelector>> selectors_set_;
};

// Manages QDQ node group selection: holds the registered selectors and an
// op type -> selector lookup used to find QDQ node groups in a graph.
class SelectorManager {
 public:
  SelectorManager() = default;

  // Registers the selectors and builds the op type lookup map.
  void Initialize();

  // Finds and returns the QDQ::NodeGroup selections in the given graph.
  // Can be used for QDQ support in different EPs.
  std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer) const;

 private:
  // Fills op_type_to_selectors_map_ from the entries in qdq_selectors_.
  void InitializeSelectorsMap();

  // Registers the selectors into qdq_selectors_.
  void CreateSelectors();

  // Owns the registered selector entries.
  Selectors qdq_selectors_;

  // Op type -> non-owning pointer to the entry in qdq_selectors_ handling it.
  std::unordered_map<std::string, const OpVersionsAndSelector*> op_type_to_selectors_map_;

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager);
};

} // namespace QDQ
} // namespace onnxruntime
34 changes: 34 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/providers/partitioning_utils.h"
#include "core/session/environment.h"
#include "core/session/inference_session.h"
Expand Down Expand Up @@ -1826,5 +1827,38 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
ASSERT_FALSE(result.has_value());
}
}

// Checks that QDQ::SelectorManager returns the expected Conv QDQ node group
// for the qdq_conv test model.
TEST(QDQTransformerTests, QDQ_Shared_GetSelectors_Test) {
  const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx");

  // Load and initialize the model with default graph optimizations.
  SessionOptions session_options;
  session_options.graph_optimization_level = TransformerLevel::Default;
  InferenceSessionWrapper session{session_options, GetEnvironment()};
  ASSERT_STATUS_OK(session.Load(model_file_name));
  ASSERT_STATUS_OK(session.Initialize());

  // The expectations below rely on node 3 being the Conv node.
  const Graph& graph = session.GetGraph();
  const auto* conv_node = graph.GetNode(3);
  ASSERT_TRUE(conv_node != nullptr);
  ASSERT_EQ("Conv", conv_node->OpType());

  const GraphViewer graph_viewer(graph);

  QDQ::SelectorManager selector_mgr;
  selector_mgr.Initialize();

  // The first selection is expected to be the Conv group:
  // DQ nodes {0, 1, 2} -> target node 3 -> Q node {4}.
  const auto selections = selector_mgr.GetQDQSelections(graph_viewer);
  ASSERT_FALSE(selections.empty());

  const auto& first_group = selections.front();
  const std::vector<NodeIndex> expected_dq_nodes{0, 1, 2};
  const std::vector<NodeIndex> expected_q_nodes{4};
  ASSERT_EQ(expected_dq_nodes, first_group.dq_nodes);
  ASSERT_EQ(NodeIndex(3), first_group.target_node);
  ASSERT_EQ(expected_q_nodes, first_group.q_nodes);
}

} // namespace test
} // namespace onnxruntime

0 comments on commit a099bd4

Please sign in to comment.