Fix trtlogger segfault. re-enable SoftPlus unit test for TRT. add documentation for ORT_TENSORRT* env vars. (microsoft#1623)

* Fix trtlogger segfault. re-enable SoftPlus unit test for TRT. add documentation for ORT_TENSORRT* env vars.

* Update TensorRT-ExecutionProvider.md
jywu-msft authored Aug 14, 2019
1 parent 09db1e0 commit 24d17f4
Showing 3 changed files with 27 additions and 6 deletions.
9 changes: 9 additions & 0 deletions docs/execution_providers/TensorRT-ExecutionProvider.md
@@ -22,3 +22,12 @@ When using the python wheel from the ONNX Runtime build with TensorRT execution

### Using onnxruntime_perf_test
You can test the performance of your ONNX model with the TensorRT execution provider. Use the flag `-e tensorrt` in [onnxruntime_perf_test](https://github.com/Microsoft/onnxruntime/tree/master/onnxruntime/test/perftest#onnxruntime-performance-test).
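For example, a typical invocation might look like the following (the model and result paths are illustrative placeholders, not part of this commit):

./onnxruntime_perf_test -e tensorrt /path/to/model.onnx /path/to/result.txt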

+### Configuring Engine Max Batch Size and Workspace Size
+By default, the TensorRT execution provider builds an ICudaEngine with max batch size = 1 and max workspace size = 1 GB.
+These defaults can be overridden with the environment variables ORT_TENSORRT_MAX_BATCH_SIZE and ORT_TENSORRT_MAX_WORKSPACE_SIZE.
+e.g. on Linux:
+#### Override default max batch size to 10
+export ORT_TENSORRT_MAX_BATCH_SIZE=10
+#### Override default max workspace size to 2 GB
+export ORT_TENSORRT_MAX_WORKSPACE_SIZE=2147483648
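To make the mechanism concrete, here is a rough C++ sketch of how a provider could pick up these overrides with the defaults stated above. This is illustrative only, not the actual ONNX Runtime implementation; the function names and fallback logic are assumptions.

```cpp
#include <cstddef>
#include <cstdlib>
#include <string>

// Hypothetical helpers (not ORT's real code): read the documented env vars,
// falling back to the stated defaults of batch size 1 and 1 GB workspace.
int GetMaxBatchSize() {
  const char* v = std::getenv("ORT_TENSORRT_MAX_BATCH_SIZE");
  return v ? std::stoi(v) : 1;
}

std::size_t GetMaxWorkspaceSize() {
  const char* v = std::getenv("ORT_TENSORRT_MAX_WORKSPACE_SIZE");
  return v ? std::stoull(v) : (1ULL << 30);  // 1 GB = 1073741824 bytes
}
```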
21 changes: 17 additions & 4 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -25,6 +25,12 @@ using namespace ::onnxruntime::logging;

namespace onnxruntime {

+// Per TensorRT documentation, the logger needs to be a singleton: one
+// long-lived instance shared by every builder, parser, and engine. Returning
+// a function-local static by reference guarantees the logger outlives any
+// TensorRT object that keeps a reference to it; the stack-local loggers used
+// previously could be destroyed while still referenced, causing the segfault
+// this commit fixes.
+TensorrtLogger& GetTensorrtLogger() {
+  static TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
+  return trt_logger;
+}

#define CHECK_CUDA(call)        \
  do {                          \
    cudaError_t status = call;  \
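For context, TensorrtLogger itself is defined elsewhere in the provider and is not part of this diff; it implements nvinfer1::ILogger. A minimal sketch of such a logger, under the assumption that it simply filters by severity (the message formatting here is invented):

```cpp
#include <iostream>
#include "NvInfer.h"

// Sketch of an ILogger implementation: print messages at or above the
// configured verbosity (in TensorRT, a lower enum value is more severe).
class TensorrtLogger : public nvinfer1::ILogger {
 public:
  explicit TensorrtLogger(Severity verbosity) : verbosity_(verbosity) {}
  void log(Severity severity, const char* msg) override {
    if (severity <= verbosity_) {
      std::cerr << "[TensorRT] " << msg << std::endl;
    }
  }

 private:
  Severity verbosity_;
};
```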
@@ -197,7 +203,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect

  // Get supported node list recursively
  SubGraphCollection_t parser_nodes_list;
-  TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
+  TensorrtLogger& trt_logger = GetTensorrtLogger();
  auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
  auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
  auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
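The unique_pointer<T> alias used here is the provider's RAII wrapper for TensorRT objects, which in the TensorRT API of this era must be released with destroy() rather than delete. Its definition is not shown in this diff; a plausible sketch follows, where InferDeleter is a hypothetical name:

```cpp
#include <memory>

// Hypothetical deleter: TensorRT objects such as IBuilder,
// INetworkDefinition, and IParser are released via obj->destroy().
struct InferDeleter {
  template <typename T>
  void operator()(T* obj) const {
    if (obj != nullptr) {
      obj->destroy();
    }
  }
};

template <typename T>
using unique_pointer = std::unique_ptr<T, InferDeleter>;
```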
@@ -255,7 +261,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,

  // Get supported node list
  SubGraphCollection_t parser_nodes_vector;
-  TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
+  TensorrtLogger& trt_logger = GetTensorrtLogger();
  auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
  auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
  auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@@ -323,7 +329,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
  model_proto.SerializeToString(&string_buf);

  // Create TensorRT engine
-  TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
+  TensorrtLogger& trt_logger = GetTensorrtLogger();
  auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
  auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
  auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@@ -490,7 +496,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:

  // Run TRT inference
  std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
-  trt_state->context->enqueue(batch_size, &buffers[0], nullptr, nullptr);
+  bool ret = trt_state->context->enqueue(batch_size, &buffers[0], nullptr, nullptr);
+  if (!ret) {
+    if (trt_state->context->getEngine().getMaxBatchSize() < batch_size) {
+      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                            "TRT enqueue failed: Set ORT_TENSORRT_MAX_BATCH_SIZE environment variable to at least " + to_string(batch_size));
+    }
+    return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to enqueue to TRT execution context.");
+  }

  return Status::OK();
};
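nvinfer1::IExecutionContext::enqueue returns false on failure, and a runtime batch larger than the engine's build-time maximum is the cause this change singles out, hence the getMaxBatchSize() comparison. The same guard as a standalone sketch (SafeEnqueue is a hypothetical helper, not part of this commit):

```cpp
#include <cuda_runtime.h>
#include "NvInfer.h"

// Hypothetical helper mirroring the check added above: reject batches that
// exceed the engine's build-time limit before handing work to TensorRT.
bool SafeEnqueue(nvinfer1::IExecutionContext* ctx, int batch_size,
                 void** buffers, cudaStream_t stream) {
  if (batch_size > ctx->getEngine().getMaxBatchSize()) {
    return false;  // rebuild with a larger ORT_TENSORRT_MAX_BATCH_SIZE instead
  }
  return ctx->enqueue(batch_size, buffers, stream, nullptr);
}
```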
3 changes: 1 addition & 2 deletions
@@ -200,8 +200,7 @@ TEST(ActivationOpTest, Softplus) {
        return x + logf(expf(-x) + 1);
      else
        return logf(expf(x) + 1);
-    },
-    {}, false);  // Disable TensorRT because result mismatches
+    });
}
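A note on the reference implementation kept above: it is the numerically stable form of softplus. Mathematically softplus(x) = log(1 + e^x), but expf(x) overflows a float for large positive x, so the test uses the identity log(1 + e^x) = x + log(1 + e^(-x)) on the positive branch, where expf(-x) stays bounded.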

TEST(ActivationOpTest, Softsign) {