forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[QDQ] Add shared qdq selectors (microsoft#10178)
* wip * wip * wip * wip * wip * save * minor changes * update test graph name * address pr comments * update * address pr comments * address pr comments * fix warning * minor include fix * update to nodegroupselectors * delete unnecessary includes Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
- Loading branch information
Showing
4 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
144 changes: 144 additions & 0 deletions
144
onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "utils.h" | ||
|
||
#include <iostream> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include <core/graph/graph_viewer.h> | ||
#include <core/providers/common.h> | ||
|
||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" | ||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" | ||
|
||
namespace onnxruntime { | ||
namespace QDQ { | ||
|
||
// Bundles `ops_and_versions_in` and `selector_in` into one OpVersionsAndSelector
// and stores it in selectors_set_. Ownership of the selector is taken over by the new entry.
void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in,
                                 std::unique_ptr<NodeGroupSelector> selector_in) {
  // The result of insert() is deliberately discarded.
  ORT_IGNORE_RETURN_VALUE(
      selectors_set_.insert(
          std::make_unique<OpVersionsAndSelector>(ops_and_versions_in, std::move(selector_in))));
}
|
||
/* Static helper functions that return the OpVersionsMap for each group of operators */ | ||
// Op types (and their accepted versions) registered with the "misc" selector.
// An empty version list places no restriction on the op version;
// MaxPool is only accepted at version 12.
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap misc_ops{
      {"Gather", {}},
      {"Reshape", {}},
      {"Transpose", {}},
      {"MaxPool", {12}},
      {"Resize", {}},
  };
  return misc_ops;
}
|
||
// Op types registered with the unary selector; no version restrictions are listed.
static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap unary_ops{
      {"AveragePool", {}},
      {"LeakyRelu", {}},
  };
  return unary_ops;
}
// Op types registered with the binary selector; no version restrictions are listed.
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap binary_ops{
      {"Add", {}},
      {"Mul", {}},
  };
  return binary_ops;
}
// Op types registered with the variadic selector; no version restrictions are listed.
static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap variadic_ops{
      {"Concat", {}},
  };
  return variadic_ops;
}
// Op types registered with the Conv selector; no version restrictions are listed.
static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap conv_ops{
      {"Conv", {}},
  };
  return conv_ops;
}
// Op types registered with the MatMul selector; no version restrictions are listed.
static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() {
  OpVersionsAndSelector::OpVersionsMap matmul_ops{
      {"MatMul", {}},
  };
  return matmul_ops;
}
|
||
/* Selector rules registration related */ | ||
// Registers a DropQDQNodeGroupSelector for the miscellaneous ops
// returned by GetMiscOpVersionsMap().
void RegisterMiscSelectors(Selectors& qdq_selectors) {
  auto drop_qdq_selector = std::make_unique<DropQDQNodeGroupSelector>();
  qdq_selectors.RegisterSelector(GetMiscOpVersionsMap(), std::move(drop_qdq_selector));
}
|
||
// Registers a UnaryNodeGroupSelector for the unary ops
// returned by GetUnaryOpVersionsMap().
void RegisterUnarySelectors(Selectors& qdq_selectors) {
  auto unary_selector = std::make_unique<UnaryNodeGroupSelector>();
  qdq_selectors.RegisterSelector(GetUnaryOpVersionsMap(), std::move(unary_selector));
}
|
||
// Registers a BinaryNodeGroupSelector for the binary ops
// returned by GetBinaryOpVersionsMap().
void RegisterBinarySelectors(Selectors& qdq_selectors) {
  auto binary_selector = std::make_unique<BinaryNodeGroupSelector>();
  qdq_selectors.RegisterSelector(GetBinaryOpVersionsMap(), std::move(binary_selector));
}
|
||
// Registers a VariadicNodeGroupSelector for the variadic ops
// returned by GetVariadicOpVersionsMap().
void RegisterVariadicSelectors(Selectors& qdq_selectors) {
  auto variadic_selector = std::make_unique<VariadicNodeGroupSelector>();
  qdq_selectors.RegisterSelector(GetVariadicOpVersionsMap(), std::move(variadic_selector));
}
|
||
// Registers a ConvNodeGroupSelector for the ops returned by GetConvOpVersionsMap().
void RegisterConvSelector(Selectors& qdq_selectors) {
  auto conv_selector = std::make_unique<ConvNodeGroupSelector>();
  qdq_selectors.RegisterSelector(GetConvOpVersionsMap(), std::move(conv_selector));
}
|
||
// Registers a MatMulNodeGroupSelector for the ops returned by GetMatMulOpVersionsMap().
void RegisterMatMulSelector(Selectors& qdq_selectors) {
  auto matmul_selector = std::make_unique<MatMulNodeGroupSelector>();
  qdq_selectors.RegisterSelector(GetMatMulOpVersionsMap(), std::move(matmul_selector));
}
|
||
void SelectorManager::CreateSelectors() { | ||
RegisterMiscSelectors(qdq_selectors_); | ||
RegisterUnarySelectors(qdq_selectors_); | ||
RegisterBinarySelectors(qdq_selectors_); | ||
RegisterVariadicSelectors(qdq_selectors_); | ||
RegisterConvSelector(qdq_selectors_); | ||
RegisterMatMulSelector(qdq_selectors_); | ||
} | ||
|
||
void SelectorManager::InitializeSelectorsMap() { | ||
for (const auto& entry : qdq_selectors_.SelectorsSet()) { | ||
for (const auto& op_info : entry->op_versions_map) { | ||
bool inserted = op_type_to_selectors_map_.insert({op_info.first, &*entry}).second; | ||
ORT_ENFORCE(inserted, "Multiple entries for operator is not supported. OpType=", op_info.first); | ||
} | ||
} | ||
} | ||
|
||
// Sets up the manager: registers all selectors, then builds the op-type lookup map.
void SelectorManager::Initialize() {
  // Registration must come first: the lookup map is built from the registered entries.
  CreateSelectors();
  InitializeSelectorsMap();
}
|
||
std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const { | ||
std::vector<NodeGroup> qdq_selections; | ||
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) { | ||
const auto* node = graph_viewer.GetNode(index); | ||
if (node->Domain() != kOnnxDomain) { | ||
continue; | ||
} | ||
|
||
auto op_rule = op_type_to_selectors_map_.find(node->OpType()); | ||
if (op_rule == op_type_to_selectors_map_.cend()) { | ||
continue; | ||
} | ||
|
||
const auto& op_versions_and_selector = *op_rule->second; | ||
|
||
// check the supported versions if specified | ||
const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second; | ||
if (!versions.empty()) { | ||
if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) { | ||
LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType(); | ||
continue; | ||
} | ||
} | ||
|
||
const auto qdq_node_group_selection = op_versions_and_selector.selector->GetQDQSelection(graph_viewer, *node); | ||
if (qdq_node_group_selection.has_value()) { | ||
const auto& qdq_group = *qdq_node_group_selection; | ||
qdq_selections.push_back(qdq_group); | ||
} | ||
} | ||
|
||
return qdq_selections; | ||
} | ||
|
||
} // namespace QDQ | ||
} // namespace onnxruntime |
76 changes: 76 additions & 0 deletions
76
onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include "core/graph/basic_types.h" | ||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" | ||
#include "core/optimizer/selectors_actions/helpers.h" | ||
|
||
namespace onnxruntime { | ||
|
||
class GraphViewer; | ||
class Node; | ||
|
||
namespace QDQ { | ||
|
||
// struct that provides a join between selector and op versions supported | ||
struct OpVersionsAndSelector { | ||
using OpVersionsMap = std::unordered_map<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>; | ||
|
||
OpVersionsAndSelector(const OpVersionsMap& ops_and_versions_in, | ||
std::unique_ptr<NodeGroupSelector> selector_in) | ||
: op_versions_map{ops_and_versions_in}, | ||
selector{std::move(selector_in)} {} | ||
|
||
OpVersionsMap op_versions_map; | ||
std::unique_ptr<NodeGroupSelector> selector; | ||
|
||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpVersionsAndSelector); | ||
}; | ||
|
||
// class that manages a set of node group selectors | ||
class Selectors { | ||
public: | ||
Selectors() = default; | ||
|
||
// register a selector for the specified ops. | ||
void RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops_and_versions_in, | ||
std::unique_ptr<NodeGroupSelector> selector_in); | ||
|
||
const std::unordered_set<std::unique_ptr<OpVersionsAndSelector>>& SelectorsSet() const { | ||
return selectors_set_; | ||
} | ||
|
||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Selectors); | ||
|
||
private: | ||
std::unordered_set<std::unique_ptr<OpVersionsAndSelector>> selectors_set_; | ||
}; | ||
|
||
// class that manages qdq node group selections | ||
class SelectorManager { | ||
public: | ||
SelectorManager() = default; | ||
|
||
void Initialize(); | ||
|
||
// Methods that finds and returns a vector of QDQ::NodeGroup in a given graph | ||
// Can be used in QDQ support in different EPs | ||
std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer) const; | ||
|
||
private: | ||
Selectors qdq_selectors_; | ||
|
||
std::unordered_map<std::string, const OpVersionsAndSelector*> op_type_to_selectors_map_; | ||
|
||
void InitializeSelectorsMap(); | ||
|
||
void CreateSelectors(); | ||
|
||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager); | ||
}; | ||
|
||
} // namespace QDQ | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters