Skip to content

Commit

Permalink
Add provision in ORT for session options to be parsed when available …
Browse files Browse the repository at this point in the history
…via model file (microsoft#2449)

* Initial commit

* Fix gitmodules

* Nits

* Nits

* Updates

* Update

* More changes

* Updates

* Update

* Some updates

* More changes

* Update

* Update

* Merge

* Update

* Updates

* More changes

* Update

* Fix nits

* Updates

* Fix warning

* Fix build

* Add comment

* PR feedback

* PR feedback

* Updates

* Updates

* Update

* More changes

* Fix build break

* Comment test for now

* Updates

* Updates

* PR feedback

* Updates

* Nits

* Add tests

* Fix build

* Fix build

* Fix build

* Fix build break

* Fix build

* Nits

* PR feedback

* More change

* Expose GetSessionOptions in pybind logic and add unit test for python

* Fix build

* PR feedback

* PR feedback
  • Loading branch information
hariharans29 authored Dec 4, 2019
1 parent 178d059 commit 5c2e474
Showing 21 changed files with 1,103 additions and 195 deletions.
4 changes: 3 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -46,4 +46,6 @@
[submodule "cmake/external/wil"]
path = cmake/external/wil
url = https://github.com/microsoft/wil

[submodule "cmake/external/json"]
path = cmake/external/json
url = https://github.com/nlohmann/json
28 changes: 27 additions & 1 deletion ThirdPartyNotices.txt
Original file line number Diff line number Diff line change
@@ -3794,4 +3794,30 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
SOFTWARE

-----

nlohmann/json

MIT License

Copyright (c) 2013-2019 Niels Lohmann

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
9 changes: 9 additions & 0 deletions cgmanifest.json
Original file line number Diff line number Diff line change
@@ -437,6 +437,15 @@
},
"type": "git"
}
},
{
"component": {
"git": {
"commitHash": "d98bf0278d6f59a58271425963a8422ff48fe249",
"repositoryUrl": "https://github.com/nlohmann/json.git"
},
"type": "git"
}
}
],
"Version": 1
1 change: 1 addition & 0 deletions cmake/external/json
Submodule json added at d98bf0
2 changes: 1 addition & 1 deletion cmake/onnxruntime_session.cmake
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxrun
if(onnxruntime_ENABLE_INSTRUMENT)
target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
endif()
target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS})
target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${PROJECT_SOURCE_DIR}/external/json ${eigen_INCLUDE_DIRS})
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
if (onnxruntime_USE_CUDA)
86 changes: 63 additions & 23 deletions onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
@@ -116,12 +116,12 @@ Model::Model(std::unique_ptr<ModelProto> model_proto, const IOnnxRuntimeOpSchema
// TODO: Check if we can upgrade all the current opset 6 models that are being tested
// in CI to opset 7 or above
LOGS(logger, WARNING) << "ONNX Runtime only *guarantees* support for models stamped "
"with opset version 7 or above for opset domain 'ai.onnx'. "
"Please upgrade your model to opset 7 or higher. "
"For now, this opset "
<< version
<< " model may run depending upon legacy support "
"of some older opset version operators.";
"with opset version 7 or above for opset domain 'ai.onnx'. "
"Please upgrade your model to opset 7 or higher. "
"For now, this opset "
<< version
<< " model may run depending upon legacy support "
"of some older opset version operators.";
}
// We need to overwrite the domain here with ("") or else the loop below will try to find ("")
// in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11).
@@ -284,10 +284,8 @@ Status Model::Load(std::unique_ptr<ModelProto> p_model_proto, std::shared_ptr<Mo
return Status::OK();
}

template <typename T>
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
template <typename T, typename Loader>
static Status LoadModelHelper(const T& file_path, Loader loader) {
int fd;
Status status = Env::Default().FileOpenRd(file_path, fd);
if (!status.IsOK()) {
@@ -304,8 +302,8 @@ static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
}
}
try {
status = Model::Load(fd, p_model, local_registries, logger);
} catch (std::exception& ex) {
status = loader(fd);
} catch (const std::exception& ex) {
GSL_SUPPRESS(es .84)
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
return Status(ONNXRUNTIME, FAIL, ex.what());
@@ -318,14 +316,34 @@ static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
return Env::Default().FileClose(fd);
}

template <typename T>
static Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto) {
const auto loader = [&model_proto](int fd) {
return Model::Load(fd, model_proto);
};

return LoadModelHelper(file_path, loader);
}

template <typename T>
static Status LoadModel(const T& file_path, std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
const auto loader = [&p_model, local_registries, &logger](int fd) {
return Model::Load(fd, p_model, local_registries, logger);
};

return LoadModelHelper(file_path, loader);
}

template <typename T>
static Status SaveModel(Model& model, const T& file_path) {
int fd;
Status status = Env::Default().FileOpenWr(file_path, fd);
ORT_RETURN_IF_ERROR(status);
try {
status = Model::Save(model, fd);
} catch (std::exception& ex) {
} catch (const std::exception& ex) {
GSL_SUPPRESS(es .84)
ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd));
return Status(ONNXRUNTIME, FAIL, ex.what());
@@ -344,6 +362,11 @@ Status Model::Save(Model& model, const std::wstring& file_path) {
}
#endif

Status Model::Load(const std::basic_string<ORTCHAR_T>& file_path,
ONNX_NAMESPACE::ModelProto& model_proto) {
return LoadModel(file_path, model_proto);
}

GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load
GSL_SUPPRESS(r .35)
Status Model::Load(const std::basic_string<ORTCHAR_T>& file_path, std::shared_ptr<Model>& p_model,
@@ -356,15 +379,25 @@ Status Model::Save(Model& model, const std::string& file_path) {
return SaveModel(model, file_path);
}

Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
std::unique_ptr<ModelProto> modelProto = onnxruntime::make_unique<ModelProto>();
const bool result = modelProto->ParseFromArray(p_bytes, count);
Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
const bool result = model_proto.ParseFromArray(p_bytes, count);
if (!result) {
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
}

p_model = std::make_shared<Model>(std::move(modelProto), local_registries, logger);
return Status::OK();
}

Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
ModelProto model_proto;

auto status = LoadFromBytes(count, p_bytes, model_proto);
if (!status.IsOK()) {
return status;
}

p_model = std::make_shared<Model>(model_proto, local_registries, logger);

ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));

@@ -375,16 +408,14 @@ using ::google::protobuf::io::CodedInputStream;
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::io::ZeroCopyInputStream;

Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
Status Model::Load(int fd, ONNX_NAMESPACE::ModelProto& model_proto) {
if (fd < 0) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<p_fd> less than 0.");
}

std::unique_ptr<ModelProto> model_proto = onnxruntime::make_unique<ModelProto>();
#if GOOGLE_PROTOBUF_VERSION >= 3002000
FileInputStream fs(fd);
const bool result = model_proto->ParseFromZeroCopyStream(&fs) && fs.GetErrno() == 0;
const bool result = model_proto.ParseFromZeroCopyStream(&fs) && fs.GetErrno() == 0;
if (!result) {
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
}
@@ -402,7 +433,16 @@ Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOp
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
}
#endif
p_model = std::make_shared<Model>(std::move(model_proto), local_registries, logger);
return Status::OK();
}

Status Model::Load(int fd, std::shared_ptr<Model>& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger) {
ModelProto model_proto;

ORT_RETURN_IF_ERROR(Load(fd, model_proto));

p_model = std::make_shared<Model>(model_proto, local_registries, logger);

ORT_RETURN_IF_ERROR(p_model->MainGraph().Resolve(true));

13 changes: 11 additions & 2 deletions onnxruntime/core/graph/model.h
Original file line number Diff line number Diff line change
@@ -26,8 +26,8 @@ class Model {
explicit Model(const std::string& graph_name,
bool is_onnx_domain_only,
const logging::Logger& logger)
:Model(graph_name,is_onnx_domain_only, ModelMetaData(),IOnnxRuntimeOpSchemaRegistryList(),{},{},
logger){}
: Model(graph_name, is_onnx_domain_only, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {}, {},
logger) {}

// Construct model from scratch.
explicit Model(const std::string& graph_name,
@@ -105,16 +105,25 @@ class Model {

static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto);

static common::Status Load(const std::basic_string<ORTCHAR_T>& file_path,
/*out*/ ONNX_NAMESPACE::ModelProto& model_proto);

// TODO(Task:132) Use of shared_ptr<X>* in Load/Save methods is confusing.
static common::Status Load(const std::basic_string<ORTCHAR_T>& file_path,
/*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);

static common::Status Load(int fd, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto);

static common::Status Load(int fd, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger);

// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
static common::Status LoadFromBytes(int count, void* pBytes,
/*out*/ ONNX_NAMESPACE::ModelProto& model_proto);

// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
static common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
6 changes: 6 additions & 0 deletions onnxruntime/core/platform/env.h
Original file line number Diff line number Diff line change
@@ -157,6 +157,12 @@ class Env {
// \brief returns a provider that will handle telemetry on the current platform
virtual const Telemetry& GetTelemetryProvider() const = 0;

// \brief returns a value for the queried variable name (var_name)
//
// Returns the corresponding value stored in the environment variable if available
// Returns empty string if there is no such environment variable available
virtual std::string GetEnvironmentVar(const std::string& var_name) const = 0;

protected:
Env();

14 changes: 10 additions & 4 deletions onnxruntime/core/platform/posix/env.cc
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ limitations under the License.
#include <dlfcn.h>
#include <string.h>
#include <thread>
#include <utility> // for std::forward
#include <utility> // for std::forward
#include <vector>
#include <assert.h>

@@ -74,7 +74,7 @@ using ScopedFileDescriptor = ScopedResource<FileDescriptorTraits>;

// non-macro equivalent of TEMP_FAILURE_RETRY, described here:
// https://www.gnu.org/software/libc/manual/html_node/Interrupted-Primitives.html
template<typename TFunc, typename... TFuncArgs>
template <typename TFunc, typename... TFuncArgs>
long int TempFailureRetry(TFunc retriable_operation, TFuncArgs&&... args) {
long int result;
do {
@@ -216,8 +216,8 @@ class PosixEnv : public Env {
}

mapped_memory = MappedMemoryPtr{
reinterpret_cast<char*>(mapped_base) + offset_to_page,
OrtCallbackInvoker{OrtCallback{UnmapFile, new UnmapFileParam{mapped_base, mapped_length}}}};
reinterpret_cast<char*>(mapped_base) + offset_to_page,
OrtCallbackInvoker{OrtCallback{UnmapFile, new UnmapFileParam{mapped_base, mapped_length}}}};

return Status::OK();
}
@@ -318,6 +318,12 @@ class PosixEnv : public Env {
return telemetry_provider_;
}

// \brief returns a value for the queried variable name (var_name)
std::string GetEnvironmentVar(const std::string& var_name) const override {
char* val = getenv(var_name.c_str());
return val == NULL ? std::string() : std::string(val);
}

private:
PosixEnv() = default;
Telemetry telemetry_provider_;
29 changes: 28 additions & 1 deletion onnxruntime/core/platform/windows/env.cc
Original file line number Diff line number Diff line change
@@ -129,7 +129,7 @@ class WindowsEnv : public Env {

size_t total_bytes_read = 0;
while (total_bytes_read < length) {
constexpr DWORD k_max_bytes_to_read = 1 << 30; // read at most 1GB each time
constexpr DWORD k_max_bytes_to_read = 1 << 30; // read at most 1GB each time
const size_t bytes_remaining = length - total_bytes_read;
const DWORD bytes_to_read = static_cast<DWORD>(std::min<size_t>(bytes_remaining, k_max_bytes_to_read));
DWORD bytes_read;
@@ -227,6 +227,33 @@ class WindowsEnv : public Env {
return telemetry_provider_;
}

// \brief returns a value for the queried variable name (var_name)
std::string GetEnvironmentVar(const std::string& var_name) const override {
// Why getenv() should be avoided on Windows:
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv
// Instead use the Win32 API: GetEnvironmentVariableA()

// Max limit of an environment variable on Windows including the null-terminating character
constexpr DWORD kBufferSize = 32767;

// Create buffer to hold the result
char buffer[kBufferSize];

auto char_count = GetEnvironmentVariableA(var_name.c_str(), buffer, kBufferSize);

// Will be > 0 if the API call was successful
if (char_count) {
return std::string(buffer, buffer + char_count);
}

// TODO: Understand the reason for failure by calling GetLastError().
// If it is due to the specified environment variable being found in the environment block,
// GetLastError() returns ERROR_ENVVAR_NOT_FOUND.
// For now, we assume that the environment variable is not found.

return std::string();
}

private:
WindowsEnv()
: GetSystemTimePreciseAsFileTime_(nullptr) {
Loading
Oops, something went wrong.

0 comments on commit 5c2e474

Please sign in to comment.