Skip to content

Commit

Permalink
Enable loading of ORT format model graph runtime optimizations (micro…
Browse files Browse the repository at this point in the history
…soft#9901)

Initial implementation of load/replay of runtime optimizations in an ORT format model.
  • Loading branch information
edgchen1 authored Jan 4, 2022
1 parent 9765949 commit 792db33
Show file tree
Hide file tree
Showing 38 changed files with 844 additions and 518 deletions.
14 changes: 7 additions & 7 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with suppor
option(onnxruntime_MINIMAL_BUILD_CUSTOM_OPS "Add custom operator kernels support to a minimal build." OFF)
option(onnxruntime_REDUCED_OPS_BUILD "Reduced set of kernels are registered in build via modification of the kernel registration source files." OFF)
option(onnxruntime_DISABLE_EXTERNAL_INITIALIZERS "Don't allow models to load external data" OFF)
cmake_dependent_option(onnxruntime_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION
"Enable runtime graph optimization of ORT format models. Warning: Not yet ready for general use."
OFF "NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD" OFF)
cmake_dependent_option(onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD
"Enable runtime graph optimization replay for ORT format models in an extended minimal build."
OFF "onnxruntime_EXTENDED_MINIMAL_BUILD" OFF)

#A special option just for debugging and sanitize check. Please do not enable in option in retail builds.
#The option has no effect on Windows.
Expand Down Expand Up @@ -340,6 +340,10 @@ if (onnxruntime_MINIMAL_BUILD)
if (onnxruntime_EXTENDED_MINIMAL_BUILD)
# enable EPs that compile kernels at runtime
add_compile_definitions(ORT_EXTENDED_MINIMAL_BUILD)

if (onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
add_compile_definitions(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
endif()
endif()

if (onnxruntime_MINIMAL_BUILD_CUSTOM_OPS)
Expand Down Expand Up @@ -372,10 +376,6 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()

if (onnxruntime_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
add_compile_definitions(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
endif()

if (onnxruntime_ENABLE_LTO)
include(CheckIPOSupported)
check_ipo_supported(RESULT ipo_enabled OUTPUT ipo_output)
Expand Down
13 changes: 3 additions & 10 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ if (onnxruntime_MINIMAL_BUILD)
"${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer.cc"
)

if (onnxruntime_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
if (onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
list(APPEND onnxruntime_optimizer_src_patterns
"${ONNXRUNTIME_ROOT}/core/optimizer/ort_format_runtime_optimization/utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/ort_format_runtime_optimization/utils.cc"
"${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/graph_transformer_utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer_utils.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/qdq_util.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/qdq_util.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/*.h"
Expand All @@ -37,13 +37,6 @@ else()
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/transpose_optimizer/*.cc"
)

if (onnxruntime_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
list(APPEND onnxruntime_optimizer_src_patterns
"${ONNXRUNTIME_ROOT}/core/optimizer/ort_format_runtime_optimization/utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/ort_format_runtime_optimization/utils.cc"
)
endif()
endif()

if (onnxruntime_ENABLE_TRAINING)
Expand Down
7 changes: 6 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ else() # minimal and/or reduced ops build
endif()
endif()

if(NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
list(APPEND onnxruntime_test_optimizer_src
"${TEST_SRC_DIR}/optimizer/runtime_optimization/graph_runtime_optimization_test.cc")
endif()

file(GLOB onnxruntime_test_training_src
"${ORTTRAINING_SOURCE_DIR}/test/model/*.cc"
"${ORTTRAINING_SOURCE_DIR}/test/gradient/*.cc"
Expand Down Expand Up @@ -687,7 +692,7 @@ AddTest(
if (MSVC)
# The warning means the type of two integral values around a binary operator is narrow than their result.
# If we promote the two input values first, it could be more tolerant to integer overflow.
# However, this is test code. We are less concerned.
# However, this is test code. We are less concerned.
target_compile_options(onnxruntime_test_all PRIVATE "/wd26451")
else()
target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses")
Expand Down
40 changes: 34 additions & 6 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
#include <unordered_map>
#include <unordered_set>

#ifdef _WIN32
#pragma warning(push)
// disable some warnings from protobuf to pass Windows build
#pragma warning(disable : 4244)
#endif

#if !defined(ORT_MINIMAL_BUILD)
#include "onnx/defs/schema.h"
#else
Expand All @@ -18,6 +24,10 @@
#include "onnx/onnx_pb.h"
#include "onnx/onnx-operators_pb.h"

#ifdef _WIN32
#pragma warning(pop)
#endif

#include "gsl/gsl"

#include "core/common/common.h"
Expand All @@ -43,7 +53,7 @@ struct IndexedSubGraph;
class Model;
class OpSignature;

#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
class RuntimeOptimizationRecordContainer;
#endif

Expand Down Expand Up @@ -778,7 +788,9 @@ class Graph {
return ConstGraphNodes(nodes_, std::move(filter_func));
}

/** Gets the maximum NodeIndex value used in the Graph. */
/** Gets the maximum NodeIndex value used in the Graph.
WARNING: This actually returns the max index value used + 1.
*/
int MaxNodeIndex() const noexcept { return static_cast<int>(nodes_.size()); } //assume the casting won't overflow

/** Gets the number of valid Nodes in the Graph.
Expand Down Expand Up @@ -977,7 +989,7 @@ class Graph {
@remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
while this is in use.
Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
*/
Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);

Expand Down Expand Up @@ -1189,15 +1201,29 @@ class Graph {
Graph& parent_graph, const Node& parent_node,
const logging::Logger& logger, std::unique_ptr<Graph>& graph);

#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const {
return runtime_optimizations_;
}

RuntimeOptimizationRecordContainer& MutableRuntimeOptimizations() {
return runtime_optimizations_;
}
#endif

// Stores information collected during the replay of loaded runtime optimizations
struct RuntimeOptimizationReplayContext {
std::unordered_map<NodeIndex, HashValue> produced_node_index_to_kernel_def_hash{};
size_t num_replayed_optimizations{};
};

const RuntimeOptimizationReplayContext& RuntimeOptimizationReplayCtx() const {
return runtime_optimization_replay_context_;
}

RuntimeOptimizationReplayContext& MutableRuntimeOptimizationReplayCtx() {
return runtime_optimization_replay_context_;
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)

// This friendship relationship should only be used to call Graph::Graph and
// Graph::LoadGraph All other access should be via the public API.
Expand Down Expand Up @@ -1414,11 +1440,13 @@ class Graph {
std::hash<std::string>, std::equal_to<std::string>>
sparse_tensor_names_;

#if defined(ORT_ENABLE_ORT_FORMAT_RUNTIME_GRAPH_OPTIMIZATION)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
// Runtime optimization storage.
// Note: runtime_optimizations_ == *runtime_optimizations_ptr_ and must be initialized
std::unique_ptr<RuntimeOptimizationRecordContainer> runtime_optimizations_ptr_;
RuntimeOptimizationRecordContainer& runtime_optimizations_;

RuntimeOptimizationReplayContext runtime_optimization_replay_context_;
#endif

#if !defined(ORT_MINIMAL_BUILD)
Expand Down
25 changes: 21 additions & 4 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,29 @@

#pragma once

#include <gsl/gsl>
#include <memory>
#include <unordered_set>
#include <vector>

#include "core/framework/session_options.h"
#include "core/optimizer/graph_transformer.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/rewrite_rule.h"
#endif

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
#endif

namespace onnxruntime {
struct FreeDimensionOverride;
class IExecutionProvider;
struct RuntimeOptimizationSaveContext;

namespace optimizer_utils {

#if !defined(ORT_MINIMAL_BUILD)

/** Generates all predefined rules for this level.
If rules_to_enable is not empty, it returns the intersection of predefined rules and rules_to_enable.
TODO: This is visible for testing at the moment, but we should rather make it private. */
Expand All @@ -40,6 +50,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
const IExecutionProvider& execution_provider /*required by constant folding*/,
const std::unordered_set<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)

/** Generates all predefined transformers which support runtime optimizations for this level.
Any transformers or rewrite rules named in rules_and_transformers_to_disable will be excluded.
Expand All @@ -48,8 +62,11 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
- The set of transformers which support runtime optimizations is different. */
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformersForRuntimeOptimizations(
TransformerLevel level,
const RuntimeOptimizationSaveContext& runtime_optimization_save_context,
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const std::unordered_set<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)

} // namespace optimizer_utils
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "ses
// Save information for replaying graph optimizations later instead of applying them directly.
//
// When an ONNX model is loaded, ORT can perform various optimizations on the graph.
// However, when an ORT format model is loaded, these optimizations are typically not available - this scenario must
// be supported by minimal builds.
// However, when an ORT format model is loaded, the logic to perform these optimizations may not be available because
// this scenario must be supported by minimal builds.
// When loading an ONNX model, ORT can optionally save the effects of some optimizations for later replay in an ORT
// format model. These are known as "runtime optimizations" - in an ORT format model, they happen at runtime.
//
Expand Down
49 changes: 45 additions & 4 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,21 +980,41 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
// kernel hashes for model are in top level SessionState
const auto& compiled_kernel_hashes = GetCompiledKernelHashes();

const bool original_nodes_should_exist =
compiled_kernel_hashes.empty()
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
&& graph_.RuntimeOptimizationReplayCtx().num_replayed_optimizations == 0
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
;

// process the nodes that existed when the model was created
for (FbsSessionStateViewer::Index i = 0, end = fbs_session_state_viewer.GetNumNodeKernelInfos(); i < end; ++i) {
const auto node_kernel_info = fbs_session_state_viewer.GetNodeKernelInfo(i);

Node* const node = graph_.GetNode(node_kernel_info.node_index);
if (node == nullptr) {
// this is OK if we have compiled kernels and the original node was replaced. if not the model is invalid.
ORT_RETURN_IF(compiled_kernel_hashes.empty(),
// this is OK if we have compiled kernels/replayed runtime optimizations and the original node was replaced.
// if not the model is invalid.
ORT_RETURN_IF(original_nodes_should_exist,
"Can't find node with index ", node_kernel_info.node_index, ". Invalid ORT format model.");
continue;
}

ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, node_kernel_info.kernel_def_hash));
}

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
// process the nodes that were added by replaying any loaded runtime optimizations
for (const auto& [node_index, kernel_def_hash] :
graph_.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) {
const auto* node = graph_.GetNode(node_index);
ORT_RETURN_IF(node == nullptr,
"Can't find runtime optimization produced node with index ", node_index);

ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, kernel_def_hash));
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)

// lookup the hashes for any nodes we compiled. the nodes indexes for compiled nodes are not in node_indices
// as they were created at runtime.
if (!compiled_kernel_hashes.empty()) {
Expand Down Expand Up @@ -1051,7 +1071,7 @@ static void ComputeConstantInitializerUseCount(const Graph& graph, std::unordere
}

Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
KernelRegistryManager& kernel_registry_manager,
const KernelRegistryManager& kernel_registry_manager,
const SessionOptions& session_options,
const onnxruntime::fbs::SessionState* serialized_session_state,
bool remove_initializers,
Expand Down Expand Up @@ -1208,7 +1228,7 @@ static void AccumulateAllNestedSubgraphsInfo(
}

Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
KernelRegistryManager& kernel_registry_manager,
const KernelRegistryManager& kernel_registry_manager,
_In_opt_ const Node* parent_node,
const SessionOptions& session_options,
bool remove_initializers,
Expand All @@ -1219,6 +1239,27 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
CreateGraphInfo();
}

#if defined(ORT_MINIMAL_BUILD) && defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
// remove any unused initializers
// not needed in a full build because unused initializers should have been removed earlier by Graph::Resolve()
// not needed in a minimal build with runtime optimizations disabled because only runtime optimizations are expected
// to possibly result in unused initializers
{
std::vector<std::string> unused_initializer_names;
for (const auto& [name, tensor_proto] : graph_.GetAllInitializedTensors()) {
ORT_UNUSED_PARAMETER(tensor_proto);
int idx;
if (!ort_value_name_idx_map_.GetIdx(name, idx).IsOK()) {
unused_initializer_names.push_back(name);
}
}

for (const auto& name : unused_initializer_names) {
graph_.RemoveInitializedTensor(name);
}
}
#endif // defined(ORT_MINIMAL_BUILD) && defined(ORT_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)

// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
std::vector<const NodeArg*> valid_outer_scope_node_args;
if (parent_node) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/session_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ class SessionState {
const KernelRegistryManager& kernel_registry_manager);

Status FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
KernelRegistryManager& kernel_registry_manager,
const KernelRegistryManager& kernel_registry_manager,
const SessionOptions& session_options = {},
const onnxruntime::fbs::SessionState* serialized_session_state = nullptr,
bool remove_initializers = true,
Expand Down Expand Up @@ -370,7 +370,7 @@ class SessionState {
#endif

Status FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
KernelRegistryManager& kernel_registry_manager,
const KernelRegistryManager& kernel_registry_manager,
_In_opt_ const Node* parent_node,
const SessionOptions& session_options,
bool remove_initializers,
Expand Down
Loading

0 comments on commit 792db33

Please sign in to comment.