forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support to save onnx graph with external initializers file. (micr…
…osoft#6911) Add functionality to the Graph class to be dumped to protobuf using an external binary file for the float initializers. This change is meant to avoid hitting the 2GB protobuf limit when dumping large graphs. This limit was particularly easy to exceed when dumping graphs after auto-diff. The use of the external file is limited to initializers larger than a user-specified threshold. This gives the possibility to users to include in the onnx file shape constants used by Reshape and Transpose used by Shape Inference.
1 parent
12b5ab3
commit 0315878
Showing
7 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
onnxruntime/test/framework/save_model_with_external_initializers.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/framework/data_types.h" | ||
#include "core/graph/model.h" | ||
#include "core/framework/tensorprotoutils.h" | ||
#include "test/test_environment.h" | ||
#include "test_utils.h" | ||
#include "test/util/include/asserts.h" | ||
|
||
#include "gtest/gtest.h" | ||
|
||
using namespace ONNX_NAMESPACE; | ||
using namespace onnxruntime; | ||
|
||
namespace onnxruntime { | ||
namespace test { | ||
|
||
void LoadSaveAndCompareModel(const std::string& input_onnx, | ||
const std::string& output_onnx, | ||
const std::string& external_init_file, | ||
size_t initializer_size_threshold) { | ||
std::shared_ptr<Model> model; | ||
ASSERT_STATUS_OK(Model::Load(ToPathString(input_onnx), model, nullptr, DefaultLoggingManager().DefaultLogger())); | ||
std::remove(output_onnx.c_str()); | ||
std::remove(external_init_file.c_str()); | ||
ASSERT_STATUS_OK(Model::SaveWithExternalInitializers(*model, ToPathString(output_onnx), external_init_file, initializer_size_threshold)); | ||
|
||
std::shared_ptr<Model> model_from_external; | ||
ASSERT_STATUS_OK(Model::Load(ToPathString(output_onnx), model_from_external, nullptr, DefaultLoggingManager().DefaultLogger())); | ||
|
||
Graph& graph = model->MainGraph(); | ||
// Perform shape inference on the graph, if this succeeds then it means that we could correctly read the | ||
// integer initializers used by reshape and transpose. | ||
ASSERT_STATUS_OK(graph.Resolve()); | ||
Graph& graph_from_external = model_from_external->MainGraph(); | ||
|
||
InitializedTensorSet initializers = graph.GetAllInitializedTensors(); | ||
InitializedTensorSet initializers_from_external = graph_from_external.GetAllInitializedTensors(); | ||
|
||
ASSERT_EQ(initializers.size(), initializers_from_external.size()); | ||
|
||
// Compare the initializers of the two versions. | ||
for (auto i : initializers) { | ||
const std::string kInitName = i.first; | ||
const ONNX_NAMESPACE::TensorProto* tensor_proto = i.second; | ||
const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = initializers_from_external[kInitName]; | ||
|
||
size_t tensor_proto_size = 0; | ||
std::unique_ptr<uint8_t[]> tensor_proto_data; | ||
ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*tensor_proto, Path(), tensor_proto_data, tensor_proto_size)); | ||
|
||
size_t from_external_tensor_proto_size = 0; | ||
std::unique_ptr<uint8_t[]> from_external_tensor_proto_data; | ||
ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*from_external_tensor_proto, Path(), from_external_tensor_proto_data, from_external_tensor_proto_size)); | ||
|
||
if (from_external_tensor_proto_size < initializer_size_threshold) { | ||
// 'Small' tensors should be embedded in the onnx file. | ||
EXPECT_EQ(from_external_tensor_proto->data_location(), ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_DEFAULT); | ||
} else { | ||
// 'Large' tensors should be added to the external binary file. | ||
EXPECT_EQ(from_external_tensor_proto->data_location(), ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); | ||
} | ||
|
||
ASSERT_EQ(tensor_proto_size, from_external_tensor_proto_size); | ||
EXPECT_EQ(memcmp(tensor_proto_data.get(), from_external_tensor_proto_data.get(), tensor_proto_size), 0); | ||
} | ||
// Cleanup. | ||
ASSERT_EQ(std::remove(output_onnx.c_str()), 0); | ||
ASSERT_EQ(std::remove(external_init_file.c_str()), 0); | ||
} | ||
|
||
TEST(SaveWithExternalInitializers, Mnist) { | ||
LoadSaveAndCompareModel("testdata/mnist.onnx", "testdata/mnist_with_external_initializers.onnx", "mnist_external_initializers.bin", 100); | ||
} | ||
|
||
} // namespace test | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters