Skip to content

Commit

Permalink
Make NNAPI EP reject nodes with no-shape inputs (microsoft#5927)
Browse files Browse the repository at this point in the history
  • Loading branch information
guoyu-wang authored Nov 25, 2020
1 parent fddbd89 commit 8736865
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,14 @@ class BaseOpSupportChecker : public IOpSupportChecker {
return 27;
}

virtual bool HasSupportedInputs(const Node& node) const;
virtual bool HasSupportedInputsImpl(const Node& node) const;

virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; }
virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 13; }

private:
bool HasSupportedOpSet(const Node& node) const;
bool HasSupportedInputs(const Node& node) const;
};

/* static */ void BaseOpSupportChecker::CreateSharedOpSupportChecker(
Expand Down Expand Up @@ -121,16 +124,23 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer
}

bool BaseOpSupportChecker::HasSupportedInputs(const Node& node) const {
  // NNAPI requires concrete shapes, so reject the node if any input's shape
  // is unknown (null); otherwise defer to the op-specific input type check.
  for (const auto* input_def : node.InputDefs()) {
    if (input_def->Shape() != nullptr)
      continue;

    LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType()
                          << "] Input [" << input_def->Name() << "] has no shape";
    return false;
  }

  return HasSupportedInputsImpl(node);
}

bool BaseOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
// We only check the type of input 0 by default
// specific op builder can override this
const auto& input = *node.InputDefs()[0];

if (nullptr == input.Shape()) {
LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
<< "] Input shape is null";
return false;
}

int32_t input_type;
if (!GetType(input, input_type))
return false;
Expand Down Expand Up @@ -170,7 +180,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
int32_t GetMinSupportedSdkVer(const Node& node, const OpSupportCheckParams& params) const override;
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const OpSupportCheckParams& params) const override;
bool HasSupportedInputs(const Node& node) const override;
bool HasSupportedInputsImpl(const Node& node) const override;
int GetMinSupportedOpSet(const Node& node) const override;
};

Expand Down Expand Up @@ -206,9 +216,9 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const Node& node) const {
return 1;
}

bool BinaryOpSupportChecker::HasSupportedInputs(const Node& node) const {
bool BinaryOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
if (node.OpType() != "QLinearAdd")
return BaseOpSupportChecker::HasSupportedInputs(node);
return BaseOpSupportChecker::HasSupportedInputsImpl(node);

// QLinearAdd
if (!HasValidBinaryOpQuantizedInputs(node))
Expand Down Expand Up @@ -511,7 +521,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
return params.use_nchw ? 29 : 28;
}

bool HasSupportedInputs(const Node& node) const override;
bool HasSupportedInputsImpl(const Node& node) const override;
};

/* static */ void ConvOpSupportChecker::CreateSharedOpSupportChecker(
Expand All @@ -524,9 +534,9 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
});
}

bool ConvOpSupportChecker::HasSupportedInputs(const Node& node) const {
bool ConvOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
if (node.OpType() != "QLinearConv")
return BaseOpSupportChecker::HasSupportedInputs(node);
return BaseOpSupportChecker::HasSupportedInputsImpl(node);

// QLinearConv only supports input of uint8 for now
if (!HasValidBinaryOpQuantizedInputs(node))
Expand Down Expand Up @@ -683,13 +693,13 @@ class GemmOpSupportChecker : public BaseOpSupportChecker {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const OpSupportCheckParams& params) const override;
bool HasSupportedInputs(const Node& node) const override;
bool HasSupportedInputsImpl(const Node& node) const override;
int GetMinSupportedOpSet(const Node& node) const override;
};

bool GemmOpSupportChecker::HasSupportedInputs(const Node& node) const {
bool GemmOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
if (node.OpType() != "QLinearMatMul")
return BaseOpSupportChecker::HasSupportedInputs(node);
return BaseOpSupportChecker::HasSupportedInputsImpl(node);

// QLinearMatMul
if (!HasValidBinaryOpQuantizedInputs(node))
Expand Down Expand Up @@ -990,7 +1000,7 @@ class DequantizeLinearOpSupportChecker : public BaseOpSupportChecker {
int32_t GetMinSupportedSdkVer(const Node& /* node */, const OpSupportCheckParams& /* params */) const override {
return 29;
}
bool HasSupportedInputs(const Node& node) const override;
bool HasSupportedInputsImpl(const Node& node) const override;
};

bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
Expand All @@ -1007,7 +1017,7 @@ bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensor
return true;
}

bool DequantizeLinearOpSupportChecker::HasSupportedInputs(const Node& node) const {
bool DequantizeLinearOpSupportChecker::HasSupportedInputsImpl(const Node& node) const {
int32_t input_type;
if (!GetType(*node.InputDefs()[0], input_type))
return false;
Expand Down
47 changes: 47 additions & 0 deletions onnxruntime/test/providers/nnapi/nnapi_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,53 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
<< "Some nodes should have been taken by the NNAPI EP";
#endif
}

TEST(NnapiExecutionProviderTest, TestNoShapeInputModel) {
  const ORTCHAR_T* model_file_name = ORT_TSTR("input_with_no_shape_test_graph.onnx");

  {
    // Build a model of two chained Add nodes; all three graph inputs (X, Y, Z)
    // deliberately carry no shape information.
    onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
    auto& graph = model.MainGraph();

    // FLOAT tensor type with no shape attached.
    ONNX_NAMESPACE::TypeProto float_tensor;
    float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);

    // node_1: X + Y -> node_1_out_1
    auto& x_def = graph.GetOrCreateNodeArg("X", &float_tensor);
    auto& y_def = graph.GetOrCreateNodeArg("Y", &float_tensor);
    auto& add1_out_def = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
    std::vector<onnxruntime::NodeArg*> add1_inputs{&x_def, &y_def};
    std::vector<onnxruntime::NodeArg*> add1_outputs{&add1_out_def};
    graph.AddNode("node_1", "Add", "node 1.", add1_inputs, add1_outputs);

    // node_2: node_1_out_1 + Z -> M
    auto& z_def = graph.GetOrCreateNodeArg("Z", &float_tensor);
    auto& m_def = graph.GetOrCreateNodeArg("M", &float_tensor);
    std::vector<onnxruntime::NodeArg*> add2_inputs{&add1_out_def, &z_def};
    std::vector<onnxruntime::NodeArg*> add2_outputs{&m_def};
    graph.AddNode("node_2", "Add", "node 2.", add2_inputs, add2_outputs);

    ASSERT_STATUS_OK(graph.Resolve());
    ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
  }

  // Load-only check: NNAPI supports the Add op, but both Add nodes here lack
  // input shapes, so verify that no node is assigned to the NNAPI EP.
  SessionOptions so;
  InferenceSessionWrapper session_object{so, GetEnvironment()};
  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<NnapiExecutionProvider>(0)));
  ASSERT_STATUS_OK(session_object.Load(model_file_name));
  ASSERT_STATUS_OK(session_object.Initialize());
  ASSERT_EQ(CountAssignedNodes(session_object.GetGraph(), kNnapiExecutionProvider), 0)
      << "No node should be taken by the NNAPI EP";
}

#endif // !(ORT_MINIMAL_BUILD

TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
Expand Down

0 comments on commit 8736865

Please sign in to comment.