From aa49e476b09b7ea20bee612c97765a5b0773a16c Mon Sep 17 00:00:00 2001
From: stevenlix <38092805+stevenlix@users.noreply.github.com>
Date: Wed, 16 Dec 2020 00:04:53 -0800
Subject: [PATCH] Fix TensorRT kernel conflict issue for subgraphs of control flow operators (#6115)
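* Add a subgraph kernel index (the `subgraph_id_` member) that GetSubGraph increments on every call
* Change kernel naming to use that index so generated kernel names do not conflict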
---
.../tensorrt/tensorrt_execution_provider.cc | 13 +++++++------
.../tensorrt/tensorrt_execution_provider.h | 8 ++++----
2 files changed, 11 insertions(+), 10 deletions(-)
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 2b7cbf8207e37..c5cd8ee9bd699 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -259,6 +259,7 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index, const GraphViewer& graph) const {
+std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph) const {
const std::vector& node_index = graph.GetNodesInTopologicalOrder();
std::unordered_set node_set;
node_set.reserve(graph_nodes_index.first.size());
@@ -605,7 +606,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph
// Assign inputs and outputs to subgraph's meta_def
auto meta_def = IndexedSubGraph_MetaDef::Create();
const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph";
- meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + std::to_string(kernels_index++);
+ meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + std::to_string(subgraph_id_++);
meta_def->domain() = kMSDomain;
for (const auto& input : inputs) {
@@ -771,11 +772,11 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t&
std::unordered_map index_to_node_map;
std::unordered_map> input_to_nodes_map, node_to_outputs_map;
std::unordered_set non_trt_node_index(node_index.begin(), node_index.end());
- int counter = 0, id = 0;
+ int id = 0;
for (const auto& group : supported_nodes_vector) {
if (!group.first.empty()) {
// Construct subgraph from node list
- std::unique_ptr sub_graph = GetSubGraph(group, counter, graph);
+ std::unique_ptr sub_graph = GetSubGraph(group, graph);
// Create node to inputs/outputs/index maps
const auto& meta_def = sub_graph->GetMetaDef();
@@ -901,10 +902,10 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
// Construct subgraph capability from node list
std::vector> result;
- int counter = 0, number_of_trt_nodes = 0;
+ int number_of_trt_nodes = 0;
for (const auto& group : supported_nodes_vector) {
if (!group.first.empty()) {
- std::unique_ptr sub_graph = GetSubGraph(group, counter, graph);
+ std::unique_ptr sub_graph = GetSubGraph(group, graph);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
number_of_trt_nodes += group.first.size();
}
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index c5273c9edfd52..08acd01baeeb4 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -126,9 +126,11 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool engine_cache_enable_ = false;
std::string cache_path_;
nvinfer1::IRuntime* runtime_ = nullptr;
-
OrtMutex tensorrt_mu_;
int device_id_;
+ AllocatorPtr allocator_;
+ mutable int subgraph_id_ = 0;
+
std::unordered_map> parsers_;
std::unordered_map> engines_;
std::unordered_map> contexts_;
@@ -139,7 +141,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
std::unordered_map>>> input_shape_ranges_;
/**Get IndexedSubGraph based on node list of the subgraph*/
- std::unique_ptr GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index,
+ std::unique_ptr GetSubGraph(SubGraph_t graph_nodes_index,
const GraphViewer& graph) const;
/**
@@ -153,7 +155,5 @@ class TensorrtExecutionProvider : public IExecutionProvider {
const GraphViewer& graph, bool* early_termination) const;
void RemoveTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph) const;
-
- AllocatorPtr allocator_;
};
} // namespace onnxruntime