From aa49e476b09b7ea20bee612c97765a5b0773a16c Mon Sep 17 00:00:00 2001 From: stevenlix <38092805+stevenlix@users.noreply.github.com> Date: Wed, 16 Dec 2020 00:04:53 -0800 Subject: [PATCH] Fix TensorRT kernel conflict issue for subgraphs of control flow operators (#6115) * add static subgraph kernel index * change kernel naming to avoid conflicts --- .../tensorrt/tensorrt_execution_provider.cc | 13 +++++++------ .../tensorrt/tensorrt_execution_provider.h | 8 ++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 2b7cbf8207e37..c5cd8ee9bd699 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -259,6 +259,7 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index, const GraphViewer& graph) const { +std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph) const { const std::vector& node_index = graph.GetNodesInTopologicalOrder(); std::unordered_set node_set; node_set.reserve(graph_nodes_index.first.size()); @@ -605,7 +606,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph // Assign inputs and outputs to subgraph's meta_def auto meta_def = IndexedSubGraph_MetaDef::Create(); const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph"; - meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + std::to_string(kernels_index++); + meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + std::to_string(subgraph_id_++); meta_def->domain() = kMSDomain; for (const auto& input : inputs) { @@ -771,11 +772,11 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t& std::unordered_map index_to_node_map; std::unordered_map> input_to_nodes_map, node_to_outputs_map; std::unordered_set non_trt_node_index(node_index.begin(), node_index.end()); - int counter = 0, id = 0; + int id = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { // Construct subgraph from node list - std::unique_ptr sub_graph = GetSubGraph(group, counter, graph); + std::unique_ptr sub_graph = GetSubGraph(group, graph); // Create node to inputs/outputs/index maps const auto& meta_def = sub_graph->GetMetaDef(); @@ -901,10 +902,10 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, // Construct subgraph capability from node list std::vector> result; - int counter = 0, number_of_trt_nodes = 0; + int number_of_trt_nodes = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { - std::unique_ptr sub_graph = GetSubGraph(group, counter, graph); + std::unique_ptr sub_graph = GetSubGraph(group, graph); result.push_back(ComputeCapability::Create(std::move(sub_graph))); number_of_trt_nodes += group.first.size(); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index c5273c9edfd52..08acd01baeeb4 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -126,9 +126,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool engine_cache_enable_ = false; std::string cache_path_; nvinfer1::IRuntime* runtime_ = nullptr; - OrtMutex tensorrt_mu_; int device_id_; + AllocatorPtr allocator_; + mutable int subgraph_id_ = 0; + std::unordered_map> parsers_; std::unordered_map> engines_; std::unordered_map> contexts_; @@ -139,7 +141,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map>>> input_shape_ranges_; /**Get IndexedSubGraph based on node list of the subgraph*/ - std::unique_ptr GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index, + std::unique_ptr GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph) const; /** @@ -153,7 +155,5 @@ class TensorrtExecutionProvider : public IExecutionProvider { const GraphViewer& graph, bool* early_termination) const; void RemoveTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph) const; - - AllocatorPtr allocator_; }; } // namespace onnxruntime