Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable loading of ORT format model graph runtime optimizations #9901

Merged
merged 39 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
fa734f5
Update comment.
edgchen1 Nov 4, 2021
ef529ed
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Nov 8, 2021
99f2e07
Limit scope of disabling warning 4244 in graph.cc and graph_viewer.cc.
edgchen1 Nov 9, 2021
35d87fb
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Nov 9, 2021
f596aba
Load runtime optimizations from SAT.
edgchen1 Nov 15, 2021
94323c7
continue implementation
edgchen1 Nov 16, 2021
a3d6015
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Nov 16, 2021
65e9fcd
Fix build.
edgchen1 Nov 16, 2021
90272b7
Set node EP and kernel create infos based on produced node hashes.
edgchen1 Nov 23, 2021
e3fdaa7
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Nov 23, 2021
adf3445
Add temporary HashValue.
edgchen1 Nov 23, 2021
4d87515
save work
edgchen1 Nov 30, 2021
a9e15ea
add commented out experimental code
edgchen1 Dec 1, 2021
4ccdca9
Fix code removing unused initializers.
edgchen1 Dec 1, 2021
fd747e8
Refine test and fix duplicate records.
edgchen1 Dec 1, 2021
49c803e
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Dec 1, 2021
86f611a
update ifdef condition and comments
edgchen1 Dec 1, 2021
4fcbdb3
Fix build.
edgchen1 Dec 1, 2021
a19e85a
Remove unused function.
edgchen1 Dec 1, 2021
b16bb4b
Rename variable.
edgchen1 Dec 2, 2021
7b0261e
Clarify comment.
edgchen1 Dec 2, 2021
1e2b394
Add num_replayed_optimizations to RuntimeOptimizationReplayContext.
edgchen1 Dec 3, 2021
91a9a0a
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Dec 8, 2021
5a6ca4d
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Dec 8, 2021
00c1ed4
Merge branch 'edgchen1/sat_runtime_optimization_load' of github.com:m…
edgchen1 Dec 10, 2021
a496718
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Dec 14, 2021
1c07356
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Dec 17, 2021
8ea3907
Use single test model generation script.
edgchen1 Dec 17, 2021
8f602bb
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Dec 28, 2021
e7adc7e
Include runtime optimization test based on CMake option.
edgchen1 Dec 28, 2021
d21910b
Add noexcept to SelectorsAndActions constructors.
edgchen1 Dec 28, 2021
bf5b498
Refactor selector action transformer helper classes.
edgchen1 Dec 30, 2021
09d2406
Address PR comments.
edgchen1 Dec 30, 2021
facfa3d
Fix unused parameter warning.
edgchen1 Dec 30, 2021
ee7a6e6
Add comment.
edgchen1 Jan 3, 2022
01ba186
Change ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION to ORT_ENABLE…
edgchen1 Jan 3, 2022
4606bcc
Merge remote-tracking branch 'origin/master' into edgchen1/sat_runtim…
edgchen1 Jan 3, 2022
4bd99e7
Fix merge.
edgchen1 Jan 3, 2022
d74fa54
Remove git checkout step from build definition.
edgchen1 Jan 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ if (onnxruntime_MINIMAL_BUILD)

if (onnxruntime_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
list(APPEND onnxruntime_optimizer_src_patterns
"${ONNXRUNTIME_ROOT}/core/optimizer/ort_format_runtime_optimization/utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/ort_format_runtime_optimization/utils.cc"
"${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/graph_transformer_utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer_utils.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/qdq_util.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/qdq_util.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.h"
Expand All @@ -37,13 +37,6 @@ else()
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.cc"
)

if (onnxruntime_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
list(APPEND onnxruntime_optimizer_src_patterns
"${ONNXRUNTIME_ROOT}/core/optimizer/ort_format_runtime_optimization/utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/ort_format_runtime_optimization/utils.cc"
)
endif()
endif()

if (onnxruntime_ENABLE_TRAINING)
Expand Down
5 changes: 5 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ else() # minimal and/or reduced ops build
endif()
endif()

if(NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD)
list(APPEND onnxruntime_test_optimizer_src
"${TEST_SRC_DIR}/optimizer/runtime_optimization/graph_runtime_optimization_test.cc")
endif()

file(GLOB onnxruntime_test_training_src
"${ORTTRAINING_SOURCE_DIR}/test/model/*.cc"
"${ORTTRAINING_SOURCE_DIR}/test/gradient/*.cc"
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/graph/basic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto;
using InitializedTensorSet = std::unordered_map<std::string, const ONNX_NAMESPACE::TensorProto*>;
using ArgNameToTypeMap = std::unordered_map<std::string, ONNX_NAMESPACE::TypeProto>;
using ProviderType = const std::string&;
using HashValue = uint64_t; // TODO remove after https://github.com/microsoft/onnxruntime/pull/9710 is merged
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved

// TODO - Evaluate switching the types below to support transparent comparators and enable
// lookups based on gsl::cstring_span<> and std::string_view. This would reduces allocations
// converting to std::string, but requires conversion to std::map<std::string, foo, std::less<>>
Expand Down
33 changes: 30 additions & 3 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
#include <unordered_map>
#include <unordered_set>

#ifdef _WIN32
#pragma warning(push)
// disable some warnings from protobuf to pass Windows build
#pragma warning(disable : 4244)
#endif

#if !defined(ORT_MINIMAL_BUILD)
#include "onnx/defs/schema.h"
#else
Expand All @@ -18,6 +24,10 @@
#include "onnx/onnx_pb.h"
#include "onnx/onnx-operators_pb.h"

#ifdef _WIN32
#pragma warning(pop)
#endif

#include "gsl/gsl"

#include "core/common/common.h"
Expand Down Expand Up @@ -779,7 +789,9 @@ class Graph {
return ConstGraphNodes(nodes_, std::move(filter_func));
}

/** Gets the maximum NodeIndex value used in the Graph. */
/** Gets the maximum NodeIndex value used in the Graph.
WARNING: This actually returns the max index value used + 1.
*/
int MaxNodeIndex() const noexcept { return static_cast<int>(nodes_.size()); } //assume the casting won't overflow

/** Gets the number of valid Nodes in the Graph.
Expand Down Expand Up @@ -978,7 +990,7 @@ class Graph {
@remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
while this is in use.
Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
*/
Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);

Expand Down Expand Up @@ -1198,7 +1210,20 @@ class Graph {
RuntimeOptimizationRecordContainer& MutableRuntimeOptimizations() {
return runtime_optimizations_;
}
#endif

// Stores information collected during the replay of loaded runtime optimizations
struct RuntimeOptimizationReplayContext {
std::unordered_map<NodeIndex, HashValue> produced_node_index_to_kernel_def_hash;
};

const RuntimeOptimizationReplayContext& RuntimeOptimizationReplayCtx() const {
return runtime_optimization_replay_context_;
}

RuntimeOptimizationReplayContext& MutableRuntimeOptimizationReplayCtx() {
return runtime_optimization_replay_context_;
}
#endif // defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph);
Expand Down Expand Up @@ -1421,6 +1446,8 @@ class Graph {
// Note: runtime_optimizations_ == *runtime_optimizations_ptr_ and must be initialized
std::unique_ptr<RuntimeOptimizationRecordContainer> runtime_optimizations_ptr_;
RuntimeOptimizationRecordContainer& runtime_optimizations_;

RuntimeOptimizationReplayContext runtime_optimization_replay_context_;
#endif

#if !defined(ORT_MINIMAL_BUILD)
Expand Down
25 changes: 21 additions & 4 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,29 @@

#pragma once

#include <gsl/gsl>
#include <memory>
#include <unordered_set>
#include <vector>

#include "core/framework/session_options.h"
#include "core/optimizer/graph_transformer.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/rewrite_rule.h"
#endif

#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
#endif

namespace onnxruntime {
struct FreeDimensionOverride;
class IExecutionProvider;
struct RuntimeOptimizationSaveContext;

namespace optimizer_utils {

#if !defined(ORT_MINIMAL_BUILD)

/** Generates all predefined rules for this level.
If rules_to_enable is not empty, it returns the intersection of predefined rules and rules_to_enable.
TODO: This is visible for testing at the moment, but we should rather make it private. */
Expand All @@ -40,6 +50,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
const IExecutionProvider& execution_provider /*required by constant folding*/,
const std::unordered_set<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD)

#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)

/** Generates all predefined transformers which support runtime optimizations for this level.
Any transformers or rewrite rules named in rules_and_transformers_to_disable will be excluded.

Expand All @@ -48,8 +62,11 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
- The set of transformers which support runtime optimizations is different. */
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformersForRuntimeOptimizations(
TransformerLevel level,
const RuntimeOptimizationSaveContext& runtime_optimization_save_context,
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const std::unordered_set<std::string>& rules_and_transformers_to_disable = {});

#endif // defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)

} // namespace optimizer_utils
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "ses
// Save information for replaying graph optimizations later instead of applying them directly.
//
// When an ONNX model is loaded, ORT can perform various optimizations on the graph.
// However, when an ORT format model is loaded, these optimizations are typically not available - this scenario must
// be supported by minimal builds.
// However, when an ORT format model is loaded, these optimizations may not be available because this scenario must be
// supported by minimal builds.
// When loading an ONNX model, ORT can optionally save the effects of some optimizations for later replay in an ORT
// format model. These are known as "runtime optimizations" - in an ORT format model, they happen at runtime.
//
Expand Down
49 changes: 45 additions & 4 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -982,21 +982,41 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
// kernel hashes for model are in top level SessionState
const auto& compiled_kernel_hashes = GetCompiledKernelHashes();

const bool original_nodes_should_exist =
compiled_kernel_hashes.empty()
#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
&& graph_.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash.empty()
#endif // defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
;

// process the nodes that existed when the model was created
for (FbsSessionStateViewer::Index i = 0, end = fbs_session_state_viewer.GetNumNodeKernelInfos(); i < end; ++i) {
const auto node_kernel_info = fbs_session_state_viewer.GetNodeKernelInfo(i);

Node* const node = graph_.GetNode(node_kernel_info.node_index);
if (node == nullptr) {
// this is OK if we have compiled kernels and the original node was replaced. if not the model is invalid.
ORT_RETURN_IF(compiled_kernel_hashes.empty(),
// this is OK if we have compiled kernels/replayed runtime optimizations and the original node was replaced.
// if not the model is invalid.
ORT_RETURN_IF(original_nodes_should_exist,
"Can't find node with index ", node_kernel_info.node_index, ". Invalid ORT format model.");
continue;
}

ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, node_kernel_info.kernel_def_hash));
}

#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
// process the nodes that were added by replaying any loaded runtime optimizations
for (const auto& [node_index, kernel_def_hash] :
graph_.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) {
const auto* node = graph_.GetNode(node_index);
ORT_RETURN_IF(node == nullptr,
"Can't find runtime optimization produced node with index ", node_index);

ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, kernel_def_hash));
}
#endif // defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)

// lookup the hashes for any nodes we compiled. the nodes indexes for compiled nodes are not in node_indices
// as they were created at runtime.
if (!compiled_kernel_hashes.empty()) {
Expand Down Expand Up @@ -1053,7 +1073,7 @@ static void ComputeConstantInitializerUseCount(const Graph& graph, std::unordere
}

Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
KernelRegistryManager& kernel_registry_manager,
const KernelRegistryManager& kernel_registry_manager,
const SessionOptions& session_options,
const onnxruntime::fbs::SessionState* serialized_session_state,
bool remove_initializers,
Expand Down Expand Up @@ -1210,7 +1230,7 @@ static void AccumulateAllNestedSubgraphsInfo(
}

Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
KernelRegistryManager& kernel_registry_manager,
const KernelRegistryManager& kernel_registry_manager,
_In_opt_ const Node* parent_node,
const SessionOptions& session_options,
bool remove_initializers,
Expand All @@ -1221,6 +1241,27 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
CreateGraphInfo();
}

#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION) && defined(ORT_MINIMAL_BUILD)
// remove any unused initializers
// not needed in a full build because unused initializers should have been removed earlier by Graph::Resolve()
// not needed in a minimal build with runtime optimizations disabled because only runtime optimizations are expected
// to possibly result in unused initializers
{
std::vector<std::string> unused_initializer_names;
for (const auto& [name, tensor_proto] : graph_.GetAllInitializedTensors()) {
ORT_UNUSED_PARAMETER(tensor_proto);
int idx;
if (!ort_value_name_idx_map_.GetIdx(name, idx).IsOK()) {
unused_initializer_names.push_back(name);
}
}

for (const auto& name : unused_initializer_names) {
graph_.RemoveInitializedTensor(name);
}
}
#endif // defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION) && defined(ORT_MINIMAL_BUILD)

// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
std::vector<const NodeArg*> valid_outer_scope_node_args;
if (parent_node) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/session_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class SessionState {
const KernelRegistryManager& kernel_registry_manager);

Status FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
KernelRegistryManager& kernel_registry_manager,
const KernelRegistryManager& kernel_registry_manager,
const SessionOptions& session_options = {},
const onnxruntime::fbs::SessionState* serialized_session_state = nullptr,
bool remove_initializers = true,
Expand Down Expand Up @@ -373,7 +373,7 @@ class SessionState {
#endif

Status FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
KernelRegistryManager& kernel_registry_manager,
const KernelRegistryManager& kernel_registry_manager,
_In_opt_ const Node* parent_node,
const SessionOptions& session_options,
bool remove_initializers,
Expand Down
7 changes: 2 additions & 5 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef _WIN32
// disable some warnings from protobuf to pass Windows build
#pragma warning(disable : 4244)
#endif
#include "core/graph/graph.h"

#include <cassert>
#include <fstream>
Expand Down Expand Up @@ -2858,7 +2855,7 @@ static void RemoveRepeatedFieldEntry(T& repeated_field, const TIter& entry_to_re
// we do this so we don't have to move all the entries past the one being deleted down one.
auto slot = entry_to_remove - repeated_field.begin();
auto last_entry = repeated_field.end() - 1;
repeated_field.SwapElements(slot, num_entries - 1);
repeated_field.SwapElements(gsl::narrow<int>(slot), gsl::narrow<int>(num_entries - 1));
repeated_field.erase(last_entry);
} else {
repeated_field.erase(entry_to_remove);
Expand Down
5 changes: 0 additions & 5 deletions onnxruntime/core/graph/graph_viewer.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef _WIN32
// disable some warnings from protobuf to pass Windows build
#pragma warning(disable : 4244)
#endif

#include "core/graph/graph_viewer.h"
#include "core/graph/indexed_sub_graph.h"

Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/core/graph/runtime_optimization_record.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <limits>
#include <vector>
#include <string>
#include <tuple> // for std::tie

#include "core/graph/basic_types.h"

Expand Down Expand Up @@ -60,6 +61,18 @@ an ORT format model. This also means that non-empty node indices here must be in
int num_variadic_inputs;
/** The number of variadic output values of the target node. */
int num_variadic_outputs;

friend bool operator==(const NodesToOptimizeIndices& a, const NodesToOptimizeIndices& b) {
const auto tied = [](const NodesToOptimizeIndices& n) {
return std::tie(n.nodes, n.num_inputs, n.num_outputs, n.variadic_input, n.variadic_output,
n.num_variadic_inputs, n.num_variadic_outputs);
};
return tied(a) == tied(b);
}

friend bool operator!=(const NodesToOptimizeIndices& a, const NodesToOptimizeIndices& b) {
return !(a == b);
}
};

struct NodeIndexAndKernelDefHash {
Expand Down
Loading