Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

subgraph type override handling and unit test #3560

Merged
merged 3 commits into from
Apr 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,8 @@ class Graph {
// perform type and shape inferencing on the subgraph and Resolve to validate
static common::Status InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types);
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
const Graph::ResolveOptions& options);

// Apply type-inference and type-checking to all inputs and initializers:
common::Status TypeCheckInputsAndInitializers();
Expand Down
8 changes: 5 additions & 3 deletions include/onnxruntime/core/graph/node_arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,17 @@ class NodeArg {
/** Validate and merge type [and shape] info from input_type.
@param strict If true, the shape update will fail if there are incompatible values.
If false, will be lenient and merge only shape info that can be validly processed.
@param override_types If true, override the existing type when the two inputs or the two outputs have different types
ytaous marked this conversation as resolved.
Show resolved Hide resolved
@returns Success unless there is existing type or shape info that can't be successfully updated. */
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger);
common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger);

/** Validate and merge type [and shape] info from node_arg.
@param strict If true, the shape update will fail if there are incompatible values.
If false, will be lenient and merge only shape info that can be validly processed.
@param override_types If true, override the existing type when the two inputs or the two outputs have different types
@returns Success unless there is existing type or shape info that can't be successfully updated. */
common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger);

common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger);
/** Gets this NodeArg as a ValueInfoProto. */
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

Expand Down
80 changes: 56 additions & 24 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ void NodeArg::ClearShape() {
}
}

common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, const logging::Logger& logger) {
common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types,
const logging::Logger& logger) {
if (!utils::HasType(node_arg_info_)) {
*node_arg_info_.mutable_type() = input_type;
type_ = DataTypeUtils::ToType(node_arg_info_.type());
Expand All @@ -229,10 +230,24 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
const auto& input_tensor_elem_type = input_tensor_type.elem_type();
const auto& current_tensor_elem_type = current_type.tensor_type().elem_type();

if (input_tensor_elem_type != current_tensor_elem_type)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. ",
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
static_cast<TensorProto_DataType>(current_tensor_elem_type));
if (input_tensor_elem_type != current_tensor_elem_type) {
if (override_types) {
DataType inferred_type = DataTypeUtils::ToType(input_type);
// The "SetType" call will override the shape information to empty.
// If the original tensor has shape information, need to set it back.
if (Shape()) {
auto old_shape = *Shape();
SetType(inferred_type);
SetShape(old_shape);
} else {
SetType(inferred_type);
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. ",
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
static_cast<TensorProto_DataType>(current_tensor_elem_type));
}
}

if (utils::HasShape(input_tensor_type)) {
auto& current_tensor_type = *current_type.mutable_tensor_type();
Expand All @@ -249,11 +264,24 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
const auto& input_tensor_type = input_type.sparse_tensor_type();
const auto input_tensor_elem_type = input_tensor_type.elem_type();
const auto current_tensor_elem_type = current_type.sparse_tensor_type().elem_type();

if (input_tensor_elem_type != current_tensor_elem_type) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. ",
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
static_cast<TensorProto_DataType>(current_tensor_elem_type));
if (override_types) {
DataType inferred_type = DataTypeUtils::ToType(input_type);
if (Shape()) {
auto old_shape = *Shape();
SetType(inferred_type);
SetShape(old_shape);
} else {
SetType(inferred_type);
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. ",
static_cast<TensorProto_DataType>(input_tensor_elem_type), " != ",
static_cast<TensorProto_DataType>(current_tensor_elem_type));
}
}

if (utils::HasShape(input_tensor_type)) {
auto& current_tensor_type = *current_type.mutable_sparse_tensor_type();
if (utils::HasShape(current_tensor_type)) {
Expand All @@ -275,11 +303,12 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu
return Status::OK();
}

common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, const logging::Logger& logger) {
common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types,
const logging::Logger& logger) {
auto status = Status::OK();

if (utils::HasType(node_arg.node_arg_info_))
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, logger);
status = UpdateTypeAndShape(node_arg.node_arg_info_.type(), strict, override_types, logger);

return status;
}
Expand Down Expand Up @@ -771,7 +800,7 @@ Graph::Graph(const Model& owning_model,
// so we prefer the shape from the initializer
name_to_type_map[tensor.name()] = t;
if (matching_graph_input != nullptr) {
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, logger));
ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t, true, false, logger));
}
} else {
// v4 and later allows a constant initializer with no matching graph input. create a NodeArg for these.
Expand Down Expand Up @@ -1398,12 +1427,12 @@ bool FullyDefinedType(const TypeProto& type_proto) {
// parameters are the Graph instance for the subgraph, the input types from the control flow node that contains
// the subgraph, and the vector to write the output from the inferencing.
using SubgraphInferencingFunc =
std::function<Status(const Node&, Graph&, const std::vector<const TypeProto*>&, std::vector<const TypeProto*>&)>;
std::function<Status(const Node&, Graph&, const std::vector<const TypeProto*>&, std::vector<const TypeProto*>&, const Graph::ResolveOptions&)>;

class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
public:
GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func)
: node_(node), graph_(graph), inferencing_func_(inferencing_func) {
GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func, const Graph::ResolveOptions& options)
: node_(node), graph_(graph), inferencing_func_(inferencing_func), options_(options) {
}

// Perform inferencing on the graph contained in GraphInferencer.
Expand All @@ -1413,7 +1442,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
const std::vector<const TensorProto*>& /*input_data*/) override {
std::vector<const TypeProto*> output_types;

auto status = inferencing_func_(node_, graph_, input_types, output_types);
auto status = inferencing_func_(node_, graph_, input_types, output_types, options_);

if (status != Status::OK()) {
fail_type_inference("Graph attribute inferencing failed: ", status.ErrorMessage());
Expand All @@ -1426,6 +1455,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
const Node& node_;
Graph& graph_;
SubgraphInferencingFunc& inferencing_func_;
const Graph::ResolveOptions& options_;
};

// An implementation of the InferenceContext interface required by operator-specific
Expand All @@ -1436,10 +1466,12 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
public:
InferenceContextImpl(Node& node,
SubgraphInferencingFunc subgraph_inferencing_func,
const Graph& graph) noexcept
const Graph& graph,
const Graph::ResolveOptions& options) noexcept
: node_(node),
subgraph_inferencing_func_(subgraph_inferencing_func),
graph_(graph) {
graph_(graph),
options_(options) {
node_output_types_.resize(node.OutputDefs().size());
}

Expand Down Expand Up @@ -1500,7 +1532,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
auto* subgraph = node_.GetMutableGraphAttribute(attribute_name);

if (subgraph) {
auto inferencer = onnxruntime::make_unique<GraphInferencerImpl>(node_, *subgraph, subgraph_inferencing_func_);
auto inferencer = onnxruntime::make_unique<GraphInferencerImpl>(node_, *subgraph, subgraph_inferencing_func_, options_);
graph_inferencer = inferencer.get();
graph_inferencers_.push_back(std::move(inferencer));
} else {
Expand All @@ -1518,11 +1550,13 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
SubgraphInferencingFunc subgraph_inferencing_func_;
std::vector<std::unique_ptr<GraphInferencerImpl>> graph_inferencers_;
const Graph& graph_;
const Graph::ResolveOptions& options_;
};

Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
const std::vector<const TypeProto*>& input_types,
std::vector<const TypeProto*>& output_types) {
std::vector<const TypeProto*>& output_types,
const Graph::ResolveOptions& options) {
auto status = Status::OK();

output_types.clear();
Expand Down Expand Up @@ -1555,7 +1589,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
const auto& subgraph_input = *subgraph_inputs->at(i);

NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
status = mutable_nodearg->UpdateTypeAndShape(input_type, true, subgraph.logger_);
status = mutable_nodearg->UpdateTypeAndShape(input_type, true, options.override_types, subgraph.logger_);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
}
Expand All @@ -1576,7 +1610,7 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
if (!subgraph_nodearg)
continue;

status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, subgraph.logger_);
status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg, true, options.override_types, subgraph.logger_);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
}
Expand All @@ -1588,8 +1622,6 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,

// now that we have handled the input types, do the type/shape inferencing for the subgraph
// to flow the type/shape info through it
// TODO: Handle override-type option correctly for subgraphs.
Graph::ResolveOptions options;
status = subgraph.PerformTypeAndShapeInferencing(options);
ORT_RETURN_IF_ERROR(status);

Expand Down Expand Up @@ -1695,7 +1727,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
// Once that completes, the outputs from the node containing the subgraph will be updated, and the final values
// returned here.
SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes);
InferenceContextImpl context(node, func, *this);
InferenceContextImpl context(node, func, *this, options);

try {
context.RunInferencing();
Expand Down
110 changes: 110 additions & 0 deletions onnxruntime/test/providers/cpu/controlflow/loop_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,116 @@ TEST(Loop, InfiniteLoopTermination) {
terminator_thread.join();
}

// Basic test to trigger the type-override logic in Graph::InferAndVerifySubgraphTypes as well as
// type/shape inferencing for the subgraph, so the type/shape info flows through
// subgraph.PerformTypeAndShapeInferencing(options).
// In this test, the main graph defines its original input/expected output as "double", whereas the
// subgraph defines them as "float".
// Expectation: with options.override_types = true the types get propagated properly into the
// subgraph and the run yields the correct output.
//
// TODO - when the input/output type in the main graph is float16, extra Cast nodes will be added and
// the input type will be changed by InsertCastTransformer for graph execution, which causes a type
// mismatch failure. Need to investigate how InsertCastTransformer works in the future.
TEST(Loop, SubgraphTypeOverride) {
  // Builds the Loop body graph. Note the "float" element types here, which deliberately
  // conflict with the "double" inputs/outputs registered on the main graph below.
  auto create_subgraph = [](const RunOptions&) {
    Model model("Loop subgraph", false, DefaultLoggingManager().DefaultLogger());
    auto& graph = model.MainGraph();

    std::vector<NodeArg*> inputs;
    std::vector<NodeArg*> outputs;

    /*
       Inputs: iter_num, cond_in, fake_in, loop carried state variables.

       iter_num_in    cond_in      fake_in   [outer_scope_0]
        (unused)        |             |            |
                    [Identity]   [Identity]   [Identity]
                        |             |            |
                    cond_out      fake_out   loop_var_0_out
    */

    // graph input types.
    TypeProto int64_scalar;
    int64_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
    int64_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

    TypeProto bool_scalar;
    bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
    bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

    // float (not double) on purpose — this is the type the override logic must reconcile.
    // Shape has one dim with no value, i.e. rank-1 with an unknown dimension.
    TypeProto float_tensor;
    float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
    float_tensor.mutable_tensor_type()->mutable_shape()->add_dim();

    // graph inputs
    auto& iter_num_in = graph.GetOrCreateNodeArg("iter_num_in", &int64_scalar);
    auto& cond_in = graph.GetOrCreateNodeArg("cond_in", &bool_scalar);
    auto& fake_in = graph.GetOrCreateNodeArg("fake_in", &float_tensor);

    // outer scope value. need type but not shape.
    auto& outer_scope_0 = graph.GetOrCreateNodeArg("outer_scope_0", &float_tensor);

    // add so that we don't end up with it being considered a graph input
    graph.AddOuterScopeNodeArg("outer_scope_0");

    // graph outputs
    auto& cond_out = graph.GetOrCreateNodeArg("cond_out", &bool_scalar);
    auto& fake_out = graph.GetOrCreateNodeArg("fake_out", &float_tensor);
    auto& loop_var_0_out = graph.GetOrCreateNodeArg("loop_var_0_out", &float_tensor);

    // cond_in -> cond_out
    {
      inputs = {&cond_in};
      outputs = {&cond_out};

      graph.AddNode("cond_in_identity", "Identity", "Forward cond_in to cond_out", inputs, outputs);
    }

    // fake_in -> fake_out
    {
      inputs = {&fake_in};
      outputs = {&fake_out};

      graph.AddNode("fake_in_identity", "Identity", "Forward fake_in to fake_out", inputs, outputs);
    }

    // outer_scope_0 -> loop_var_0_out
    {
      inputs = {&outer_scope_0};
      outputs = {&loop_var_0_out};

      graph.AddNode("loop_var_out", "Identity", "Forward outer_scope_0 to loop_var_0_out", inputs, outputs);
    }

    graph.SetInputs({&iter_num_in, &cond_in, &fake_in});
    graph.SetOutputs({&cond_out, &fake_out, &loop_var_0_out});

    auto status = graph.Resolve();
    EXPECT_EQ(status, Status::OK());

    return graph.ToGraphProto();
  };

  LoopOpTester test{{}, create_subgraph};

  // Main-graph inputs are double; the subgraph above declared float — the mismatch under test.
  test.AddInput<int64_t>("M", {1}, {1});
  test.AddInput<bool>("cond", {1}, {true});
  test.AddInput<double>("fake", {1}, {0.f});
  test.AddInput<double>("outer_scope_0", {1}, {kOuterNodeAddValue});

  test.AddOutput<double>("loop_fake_final", {1}, {0.f});
  test.AddOutput<double>("loop_var_0_final", {1, 1}, {kOuterNodeAddValue});
  test.AddOutput<int64_t>("outer_scope_0_out", {1}, {int64_t(kOuterNodeAddValue)});

  OrtRunOptions session_run_options;
  session_run_options.run_tag = "Loop.SubgraphTypeOverride";

  // Enable the override so Graph::Resolve accepts the double/float type conflict
  // instead of failing with a type-mismatch error.
  Graph::ResolveOptions options;
  options.override_types = true;
  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
           {kTensorrtExecutionProvider}, &session_run_options, nullptr,
           ExecutionMode::ORT_SEQUENTIAL, {}, options);
}

// Regression test that a subgraph input overrides an outer scope value of the same name.
// Replicate issue from https://github.com/onnx/onnx/issues/2082
TEST(Loop, SubgraphInputShadowsOuterScopeValue) {
Expand Down
12 changes: 7 additions & 5 deletions onnxruntime/test/providers/provider_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,14 +626,15 @@ void OpTester::Run(
const RunOptions* run_options,
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers,
ExecutionMode execution_mode,
const CustomOutputVerifierFn& custom_output_verifier) {
const CustomOutputVerifierFn& custom_output_verifier,
const Graph::ResolveOptions& options) {
SessionOptions so;
so.session_logid = op_;
so.session_log_verbosity_level = 1;
so.execution_mode = execution_mode;
so.graph_optimization_level = TransformerLevel::Default; // 'Default' == off
Run(so, expect_result, expected_failure_string, excluded_provider_types,
run_options, execution_providers, custom_output_verifier);
run_options, execution_providers, custom_output_verifier, options);
}

void OpTester::Run(
Expand All @@ -643,7 +644,8 @@ void OpTester::Run(
const std::unordered_set<std::string>& excluded_provider_types,
const RunOptions* run_options,
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers,
const CustomOutputVerifierFn& custom_output_verifier) {
const CustomOutputVerifierFn& custom_output_verifier,
const Graph::ResolveOptions& options) {
std::string cur_provider = "not set";
try {
#ifndef NDEBUG
Expand All @@ -660,12 +662,12 @@ void OpTester::Run(
expect_result == ExpectResult::kExpectFailure) {
// capture possible exceptions from shape inference for invalid testcase
try {
status = graph.Resolve();
status = graph.Resolve(options);
} catch (const std::exception& ex) {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
}
} else {
status = graph.Resolve();
status = graph.Resolve(options);
}

if (!status.IsOK()) {
Expand Down
Loading