Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shared provider unload crash #5523

Merged
merged 6 commits into from
Oct 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/provider_shutdown.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace onnxruntime {
// Unloads the dynamically loaded execution-provider libraries. The definition
// in provider_bridge_ort.cc unloads the DNNL and TensorRT provider libraries
// first and the common "onnxruntime_providers_shared" library last.
void UnloadSharedProviders();
}
85 changes: 57 additions & 28 deletions onnxruntime/core/framework/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_provider.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/provider_shutdown.h"
#include "core/graph/model.h"
#include "core/platform/env.h"
#include "core/providers/common.h"
Expand Down Expand Up @@ -328,7 +329,7 @@ struct ProviderHostImpl : ProviderHost {
return onnxruntime::make_unique<logging::Capture>(logger, severity, category, dataType, location);
}
void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; }
std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); }
std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); }

// Provider_TypeProto_Tensor
int32_t Provider_TypeProto_Tensor__elem_type(const Provider_TypeProto_Tensor* p) override { return p->elem_type(); }
Expand Down Expand Up @@ -609,62 +610,95 @@ struct ProviderHostImpl : ProviderHost {
} provider_host_;

// Owns the handle of the "onnxruntime_providers_shared" library.
// NOTE(review): this hunk is rendered without +/- markers, so lines from both
// sides of the change appear next to each other (e.g. `return;` directly
// followed by `return false;`, and two constructors/destructors).
struct ProviderSharedLibrary {
ProviderSharedLibrary() {
// Loads the shared library on first use and passes it the host interface
// through its exported Provider_SetHost function. Returns true when the
// library is loaded (including when it already was), false when loading
// fails; the load error is logged.
bool Ensure() {
if (handle_)
return true;

std::string full_path = Env::Default().GetRuntimePath() + std::string(LIBRARY_PREFIX "onnxruntime_providers_shared" LIBRARY_EXTENSION);
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
if (!error.IsOK()) {
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
return;
return false;
}

// NOTE(review): the result of GetSymbolFromLibrary is ignored and
// PProvider_SetHost has no initializer, so a failed lookup would presumably
// call through an indeterminate pointer - confirm whether a check is needed.
void (*PProvider_SetHost)(void*);
Env::Default().GetSymbolFromLibrary(handle_, "Provider_SetHost", (void**)&PProvider_SetHost);

PProvider_SetHost(&provider_host_);
return true;
}

~ProviderSharedLibrary() {
Env::Default().UnloadDynamicLibrary(handle_);
// Unloads the library if it is loaded and clears the handle, so a second
// call does nothing.
void Unload() {
if (handle_) {
Env::Default().UnloadDynamicLibrary(handle_);
handle_ = nullptr;
}
}

ProviderSharedLibrary() = default;
// Destruction does not unload anything; it only asserts that the handle has
// already been cleared by Unload().
~ProviderSharedLibrary() { assert(!handle_); }

private:
// Library handle; null while the library is not loaded.
void* handle_{};

ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderSharedLibrary);
};

bool EnsureSharedProviderLibrary() {
static ProviderSharedLibrary shared_library;
return shared_library.handle_;
}
static ProviderSharedLibrary s_library_shared;

// Wraps one provider library (identified by file name) and the Provider
// interface obtained from it.
// NOTE(review): this hunk is rendered without +/- markers, so lines from both
// sides of the change appear next to each other.
struct ProviderLibrary {
ProviderLibrary(const char* filename) {
if (!EnsureSharedProviderLibrary())
return;
// Only records the file name; nothing is loaded until Get() is called.
ProviderLibrary(const char* filename) : filename_{filename} {}
~ProviderLibrary() { assert(!handle_); } // We should already be unloaded at this point

std::string full_path = Env::Default().GetRuntimePath() + std::string(filename);
// Returns the cached Provider if there is one. Otherwise makes sure the
// shared library is loaded, loads this provider's library from the runtime
// path and caches the result of its exported GetProvider() function.
// Returns nullptr when either library fails to load (the error is logged).
Provider* Get() {
if (provider_)
return provider_;

if (!s_library_shared.Ensure())
return nullptr;

std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_);
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
if (!error.IsOK()) {
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
return;
return nullptr;
}

// NOTE(review): the result of GetSymbolFromLibrary is ignored and
// PGetProvider has no initializer, so a failed lookup would presumably call
// through an indeterminate pointer - confirm whether a check is needed.
Provider* (*PGetProvider)();
Env::Default().GetSymbolFromLibrary(handle_, "GetProvider", (void**)&PGetProvider);

provider_ = PGetProvider();
return provider_;
}

~ProviderLibrary() {
Env::Default().UnloadDynamicLibrary(handle_);
// If the library is loaded: calls the provider's Shutdown() (when a provider
// was obtained) before unloading the library, then clears both pointers so a
// second call does nothing.
void Unload() {
if (handle_) {
if (provider_)
provider_->Shutdown();

Env::Default().UnloadDynamicLibrary(handle_);
handle_ = nullptr;
provider_ = nullptr;
}
}

private:
// Stored as a raw pointer, not copied; the instances visible in this file are
// constructed from string literals.
const char* filename_;
// Interface returned by the library's GetProvider(); null until Get() succeeds.
Provider* provider_{};
// Library handle; null while the library is not loaded.
void* handle_{};

ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderLibrary);
};

static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);

// Unloads every provider library managed by this file. The individual provider
// libraries are unloaded first and the shared library last, the reverse of the
// load order used by ProviderLibrary::Get (which ensures the shared library
// before loading a provider).
void UnloadSharedProviders() {
RyanUnderhill marked this conversation as resolved.
Show resolved Hide resolved
s_library_dnnl.Unload();
s_library_tensorrt.Unload();
s_library_shared.Unload();
}

// This class translates the IExecutionProviderFactory interface to work with the interface providers implement
struct IExecutionProviderFactory_Translator : IExecutionProviderFactory {
IExecutionProviderFactory_Translator(std::shared_ptr<Provider_IExecutionProviderFactory> p) : p_{p} {}
Expand All @@ -677,30 +711,25 @@ struct IExecutionProviderFactory_Translator : IExecutionProviderFactory {
std::shared_ptr<Provider_IExecutionProviderFactory> p_;
};

// Returns a factory for the DNNL execution provider, wrapped in
// IExecutionProviderFactory_Translator, or nullptr when the provider cannot be
// obtained. In the `use_arena` form the argument is forwarded unchanged to the
// provider's CreateExecutionProviderFactory.
// NOTE(review): two versions of this function are interleaved below (the
// `device_id` signature with a function-local static library, and the
// `use_arena` signature using s_library_dnnl).
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int device_id) {
static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
if (!library.provider_)
return nullptr;
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int use_arena) {
if (auto provider = s_library_dnnl.Get())
return std::make_shared<IExecutionProviderFactory_Translator>(provider->CreateExecutionProviderFactory(use_arena));

//return std::make_shared<onnxruntime::MkldnnProviderFactory>(device_id);
//TODO: This is apparently a bug. The constructor parameter is create-arena-flag, not the device-id
return std::make_shared<IExecutionProviderFactory_Translator>(library.provider_->CreateExecutionProviderFactory(device_id));
return nullptr;
}

// Returns a factory for the TensorRT execution provider, wrapped in
// IExecutionProviderFactory_Translator, or nullptr when the provider cannot be
// obtained. `device_id` is forwarded unchanged to the provider's
// CreateExecutionProviderFactory.
// NOTE(review): two versions of the body are interleaved below (one using a
// function-local static library, one using s_library_tensorrt).
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id) {
static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
if (!library.provider_)
return nullptr;
if (auto provider = s_library_tensorrt.Get())
return std::make_shared<IExecutionProviderFactory_Translator>(provider->CreateExecutionProviderFactory(device_id));

return std::make_shared<IExecutionProviderFactory_Translator>(library.provider_->CreateExecutionProviderFactory(device_id));
return nullptr;
}

} // namespace onnxruntime

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) {
auto factory = onnxruntime::CreateExecutionProviderFactory_Dnnl(use_arena);
if (!factory) {
LOGS_DEFAULT(ERROR) << "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library";
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library");
}

Expand Down
34 changes: 15 additions & 19 deletions onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,6 @@
#include "dnnl_execution_provider.h"
#include "dnnl_fwd.h"

namespace {

struct KernelRegistryAndStatus {
std::shared_ptr<onnxruntime::Provider_KernelRegistry> kernel_registry{onnxruntime::Provider_KernelRegistry::Create()};

Status st;
};

} // namespace

namespace onnxruntime {

constexpr const char* DNNL = "Dnnl";
Expand Down Expand Up @@ -62,18 +52,24 @@ Status RegisterDNNLKernels(Provider_KernelRegistry& kernel_registry) {
return Status::OK();
}

KernelRegistryAndStatus GetDnnlKernelRegistry() {
KernelRegistryAndStatus ret;
ret.st = RegisterDNNLKernels(*ret.kernel_registry);
return ret;
}
} // namespace ort_dnnl

static std::shared_ptr<onnxruntime::Provider_KernelRegistry> s_kernel_registry;

void Shutdown_DeleteRegistry() {
s_kernel_registry.reset();
}

// Returns the DNNL kernel registry. In the s_kernel_registry form below the
// registry is created and populated on first use; if registration fails the
// registry is released again before the error is raised through
// ORT_THROW_IF_ERROR, so the next call starts from an empty registry.
// NOTE(review): a review thread and both sides of the diff are embedded in
// this function as rendered.
std::shared_ptr<Provider_KernelRegistry> DNNLExecutionProvider::Provider_GetKernelRegistry() const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::shared_ptr<Provider_KernelRegistry> [](start = 0, length = 40)

The issue with giving away shared_ptr is that even if you destroy your copy of the ptr, someone will cache it someplace and then will still try to call after unloading.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but onnxruntime is the only one with it at that point, so it's good. Since it is the only user of this and it is the one that destroys them.

static KernelRegistryAndStatus k = onnxruntime::ort_dnnl::GetDnnlKernelRegistry();
// throw if the registry failed to initialize
ORT_THROW_IF_ERROR(k.st);
return k.kernel_registry;
// NOTE(review): this check-then-create on a namespace-scope static has no
// synchronization visible here, unlike the function-local static shown above;
// confirm that first use cannot race between threads.
if (!s_kernel_registry) {
s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create();
auto status = ort_dnnl::RegisterDNNLKernels(*s_kernel_registry);
if (!status.IsOK())
s_kernel_registry.reset();
ORT_THROW_IF_ERROR(status);
}

return s_kernel_registry;
}

bool DNNLExecutionProvider::UseSubgraph(const onnxruntime::Provider_GraphViewer& graph_viewer) const {
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ using namespace onnxruntime;

namespace onnxruntime {

void Shutdown_DeleteRegistry();

struct DnnlProviderFactory : Provider_IExecutionProviderFactory {
DnnlProviderFactory(bool create_arena) : create_arena_(create_arena) {}
~DnnlProviderFactory() override {}
Expand Down Expand Up @@ -47,9 +49,10 @@ struct Dnnl_Provider : Provider {
return std::make_shared<DnnlProviderFactory>(use_arena != 0);
}

void SetProviderHost(ProviderHost& host) {
onnxruntime::SetProviderHost(host);
void Shutdown() override {
Shutdown_DeleteRegistry();
}

} g_provider;

} // namespace onnxruntime
Expand Down
4 changes: 1 addition & 3 deletions onnxruntime/core/providers/shared_library/provider_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct Provider_NodeAttributes;
struct Provider_OpKernelContext;
struct Provider_OpKernelInfo;
struct Provider_Tensor;
}
} // namespace onnxruntime

#include "provider_interfaces.h"

Expand Down Expand Up @@ -127,8 +127,6 @@ enum OperatorStatus : int {

namespace onnxruntime {

void SetProviderHost(ProviderHost& host);

// The function passed in will be run on provider DLL unload. This is used to free thread_local variables that are in threads we don't own.
// Since these are not destroyed when the DLL unloads, we have to do it manually. Search for usages of RunOnUnload for an example.
void RunOnUnload(std::function<void()> function);
Expand Down
25 changes: 13 additions & 12 deletions onnxruntime/core/providers/shared_library/provider_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ struct Provider_IExecutionProvider {

// Interface a provider library exposes to the host; the host obtains it from
// the library's exported GetProvider() function.
struct Provider {
// Creates the factory for this provider's execution provider. How the integer
// is interpreted depends on the provider: the TensorRT provider treats it as a
// device id, while the DNNL provider treats it as a use-arena flag.
virtual std::shared_ptr<Provider_IExecutionProviderFactory> CreateExecutionProviderFactory(int device_id) = 0;
// Called by the host (ProviderLibrary::Unload) right before the provider
// library is unloaded, so the provider can release state it still holds.
virtual void Shutdown() = 0;
};

// There are two ways to route a function, one is a virtual method and the other is a function pointer (or pointer to member function)
Expand Down Expand Up @@ -543,35 +544,35 @@ struct CPUIDInfo {
bool HasAVX2() const { return g_host->CPUIDInfo__HasAVX2(this); }
bool HasAVX512f() const { return g_host->CPUIDInfo__HasAVX512f(this); }

PROVIDER_DISALLOW_ALL(CPUIDInfo)
PROVIDER_DISALLOW_ALL(CPUIDInfo)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: formatting

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like clang format fixed it and this is correct now? (I might be misunderstanding it)

};

namespace logging {

struct Logger {
bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); }
bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); }

PROVIDER_DISALLOW_ALL(Logger)
PROVIDER_DISALLOW_ALL(Logger)
};

struct LoggingManager {
static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); }
static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); }

PROVIDER_DISALLOW_ALL(LoggingManager)
};

struct Capture {
static std::unique_ptr<Capture> Create(const Logger& logger, logging::Severity severity, const char* category,
logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); }
static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast<Capture*>(p)); }
static std::unique_ptr<Capture> Create(const Logger& logger, logging::Severity severity, const char* category,
logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); }
static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast<Capture*>(p)); }

std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); }
std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); }

Capture() = delete;
Capture(const Capture&) = delete;
void operator=(const Capture&) = delete;
Capture() = delete;
Capture(const Capture&) = delete;
void operator=(const Capture&) = delete;
};
}
} // namespace logging

struct Provider_TypeProto_Tensor {
int32_t elem_type() const { return g_host->Provider_TypeProto_Tensor__elem_type(this); }
Expand Down
26 changes: 13 additions & 13 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::logging;
namespace fs = std::experimental::filesystem;
namespace {
struct KernelRegistryAndStatus {
std::shared_ptr<onnxruntime::Provider_KernelRegistry> kernel_registry{onnxruntime::Provider_KernelRegistry::Create()};
Status st;
};

std::string GetEnginePath(const ::std::string& root, const std::string& name) {
if (root.empty()) {
return name + ".engine";
Expand Down Expand Up @@ -151,17 +146,22 @@ static Status RegisterTensorrtKernels(Provider_KernelRegistry& kernel_registry)
return Status::OK();
}

KernelRegistryAndStatus GetTensorrtKernelRegistry() {
KernelRegistryAndStatus ret;
ret.st = RegisterTensorrtKernels(*ret.kernel_registry);
return ret;
// Kernel registry shared by this provider library. It starts out empty and is
// emptied again by Shutdown_DeleteRegistry().
static std::shared_ptr<onnxruntime::Provider_KernelRegistry> s_kernel_registry;

// Drops this library's reference to the kernel registry.
// NOTE(review): a review thread is embedded in the function body as rendered.
void Shutdown_DeleteRegistry() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is calling it ShutdownRegistry enough?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is for this issue. If we have other Onnxruntime objects that need to be destroyed we can do it here too. I've been thinking of ways to prove it's clean on shutdown after this fix to prevent future issues.

s_kernel_registry.reset();
}

// Returns the TensorRT kernel registry. In the s_kernel_registry form below
// the registry is created and populated on first use; if registration fails
// the registry is released again before the error is raised through
// ORT_THROW_IF_ERROR, so the next call starts from an empty registry.
// NOTE(review): both sides of the diff are interleaved in this function as
// rendered.
std::shared_ptr<Provider_KernelRegistry> TensorrtExecutionProvider::Provider_GetKernelRegistry() const {
static KernelRegistryAndStatus k = onnxruntime::GetTensorrtKernelRegistry();
// throw if the registry failed to initialize
ORT_THROW_IF_ERROR(k.st);
return k.kernel_registry;
// NOTE(review): this check-then-create on a namespace-scope static has no
// synchronization visible here, unlike the function-local static shown above;
// confirm that first use cannot race between threads.
if (!s_kernel_registry) {
s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create();
auto status = RegisterTensorrtKernels(*s_kernel_registry);
if (!status.IsOK())
s_kernel_registry.reset();
ORT_THROW_IF_ERROR(status);
}

return s_kernel_registry;
}

// Per TensorRT documentation, logger needs to be a singleton.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using namespace onnxruntime;

namespace onnxruntime {

void Shutdown_DeleteRegistry();

struct TensorrtProviderFactory : Provider_IExecutionProviderFactory {
TensorrtProviderFactory(int device_id) : device_id_(device_id) {}
~TensorrtProviderFactory() override {}
Expand Down Expand Up @@ -37,9 +39,10 @@ struct Tensorrt_Provider : Provider {
return std::make_shared<TensorrtProviderFactory>(device_id);
}

void SetProviderHost(ProviderHost& host) {
onnxruntime::SetProviderHost(host);
void Shutdown() override {
Shutdown_DeleteRegistry();
}

} g_provider;

} // namespace onnxruntime
Expand Down
Loading