Skip to content

Commit

Permalink
[QDQ] Add shared NodeUnit class (microsoft#10052)
Browse files Browse the repository at this point in the history
* initial change

* move more function to node_unit

* Remove commented code

* Minor update

* Update onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* address CR comments

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
  • Loading branch information
guoyu-wang and edgchen1 authored Dec 17, 2021
1 parent ef36488 commit f3c72de
Show file tree
Hide file tree
Showing 11 changed files with 525 additions and 259 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,8 @@ if (onnxruntime_USE_NNAPI_BUILTIN)
file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
)

if(CMAKE_SYSTEM_NAME STREQUAL "Android")
Expand Down
22 changes: 13 additions & 9 deletions onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
#include <core/graph/graph_viewer.h>
#include <core/providers/common.h>

#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/shared/utils/utils.h"
#include "helper.h"
#include "op_support_checker.h"

namespace onnxruntime {
namespace nnapi {

using onnxruntime::NodeUnit;
using std::string;
using std::vector;

namespace onnxruntime {
namespace nnapi {

std::string GetErrorCause(int error_code) {
switch (error_code) {
case ANEURALNETWORKS_NO_ERROR:
Expand Down Expand Up @@ -434,22 +436,24 @@ bool IsInternalQuantizationSupported(const Node& node, const std::unordered_set<
return true;
}

bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const OpSupportCheckParams& params) {
bool IsNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph_viewer, const OpSupportCheckParams& params) {
const auto& op_support_checkers = GetOpSupportCheckers();
if (!Contains(op_support_checkers, node.OpType()))
if (!Contains(op_support_checkers, node_unit.OpType()))
return false;

const auto* op_support_checker = op_support_checkers.at(node.OpType());
return op_support_checker->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, params);
const auto* op_support_checker = op_support_checkers.at(node_unit.OpType());
return op_support_checker->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node_unit, params);
}

bool IsNodeSupportedInGroup(const Node& node, const GraphViewer& graph_viewer,
bool IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer,
const OpSupportCheckParams& params,
const std::unordered_set<std::string>& node_outputs_in_group) {
if (!IsNodeSupported(node, graph_viewer, params))
if (!IsNodeSupported(node_unit, graph_viewer, params))
return false;

// TODO: skip this step if the node_unit is a QDQ node unit
// We also want to check if the node is supported as an internal quantized node
const auto& node = node_unit.GetNode();
if (IsInternalQuantizedNode(node))
return IsInternalQuantizationSupported(node, node_outputs_in_group);
else // This is not an internal quantized node, so it is supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using InitializerMap = std::unordered_map<std::string, const ONNX_NAMESPACE::Ten

class Node;
class NodeArg;
class NodeUnit;
class GraphViewer;

namespace nnapi {
Expand Down Expand Up @@ -125,11 +126,11 @@ bool GetType(const NodeArg& node_arg, int32_t& type);
void GetFlattenOutputShape(const Node& node, const Shape& input_shape, int32_t& dim_1, int32_t& dim_2);

// If a node is supported by NNAPI
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const OpSupportCheckParams& params);
bool IsNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph_viewer, const OpSupportCheckParams& params);

// If a node is supported by NNAPI in a partition node group
// `node_outputs_in_group` is the set of the output names of the nodes added to this group so far
bool IsNodeSupportedInGroup(const Node& node, const GraphViewer& graph_viewer,
bool IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer,
const OpSupportCheckParams& params,
const std::unordered_set<std::string>& node_outputs_in_group);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
#include <core/framework/tensorprotoutils.h>

#include "core/providers/common.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h"
#include "helper.h"
#include "model_builder.h"
#include "op_builder.h"
#include "op_support_checker.h"

namespace onnxruntime {
namespace nnapi {

using onnxruntime::NodeUnit;
using namespace android::nn::wrapper;
using std::vector;

namespace onnxruntime {
namespace nnapi {

ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer)
: nnapi_(NnApiImplementation()), graph_viewer_(graph_viewer) {}

Expand Down Expand Up @@ -120,7 +122,8 @@ void ModelBuilder::PreprocessInitializers() {
for (size_t i = 0; i < node_indices.size(); i++) {
const auto* node(graph_viewer_.GetNode(node_indices[i]));
if (const auto* op_builder = GetOpBuilder(*node)) {
op_builder->AddInitializersToSkip(*this, *node);
const NodeUnit node_unit(*node);
op_builder->AddInitializersToSkip(*this, node_unit);
}
}
}
Expand Down Expand Up @@ -192,7 +195,7 @@ static Status GetInputDataType(
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
type = Type::TENSOR_FLOAT32;
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
// For ONNX the quantized input/initializer does not carry scale and zero point info
// So we will need to search the operator using this input
// And dig out the scale and zero point as the input initializers to the operator
Expand All @@ -205,9 +208,12 @@ static Status GetInputDataType(
}

// TODO: verify that the scale and zero point match if there are multiple ops using the same input
const auto* node = all_quantized_op_inputs.at(name)[0];
const NodeUnit node_unit(*node);
ORT_RETURN_IF_ERROR(GetQuantizedInputScaleAndZeroPoint(
initializers, *all_quantized_op_inputs.at(name)[0], name, scale, zero_point));
initializers, node_unit, name, scale, zero_point));
break;
}
// case ONNX_NAMESPACE::TensorProto_DataType_INT8:
// We also do not consider ONNX_NAMESPACE::TensorProto_DataType_INT8 case here, since that can only
// be input 2 of Qlinear[Conv/MatMul], which has to be an initializer tensor and will be added
Expand Down Expand Up @@ -488,7 +494,8 @@ Status ModelBuilder::AddOperations() {
for (size_t i = 0; i < node_indices.size(); i++) {
const auto* node(graph_viewer_.GetNode(node_indices[i]));
if (const auto* op_builder = GetOpBuilder(*node)) {
ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node));
const NodeUnit node_unit(*node);
ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, node_unit));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Node [", node->Name(), "], type [", node->OpType(), "] is not supported");
Expand Down
Loading

0 comments on commit f3c72de

Please sign in to comment.