Skip to content

Commit

Permalink
Support export >2G model when using optimizer.py only (microsoft#6014)
Browse files Browse the repository at this point in the history
* checkin

* add warning if user specify same inut and output path
  • Loading branch information
wangyems authored Dec 8, 2020
1 parent b348538 commit fa06be2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/transformers/fusion_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
value=helper.make_tensor(name='const_tensor',
data_type=TensorProto.INT64,
dims=shape_value.shape,
vals=shape_value))
vals=bytes(shape_value),
raw=True))
reshape_node.input[1] = constant_shape_name
reshape_node.name = self.model.create_node_name('Reshape', 'Reshape_Fuse')
self.nodes_to_remove.extend([concat_node])
Expand Down
10 changes: 9 additions & 1 deletion onnxruntime/python/tools/transformers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def _parse_arguments():
default=0,
help="onnxruntime optimization level. 0 will disable onnxruntime.")

parser.add_argument('--use_external_data_format', required=False, action='store_true', help="use external data format")
parser.set_defaults(use_external_data_format=False)

args = parser.parse_args()

return args
Expand Down Expand Up @@ -321,6 +324,11 @@ def main():

_setup_logger(args.verbose)

if os.path.realpath(args.input) == os.path.realpath(args.output):
logger.warning(
f"Specified the same input and output path. Note that this may overwrite the original model"
)

optimization_options = _get_optimization_options(args)

optimizer = optimize_model(args.input,
Expand All @@ -338,7 +346,7 @@ def main():
if args.input_int32:
optimizer.change_input_to_int32()

optimizer.save_model_to_file(args.output)
optimizer.save_model_to_file(args.output, args.use_external_data_format)

if optimizer.is_fully_optimized():
logger.info("The model has been fully optimized.")
Expand Down

0 comments on commit fa06be2

Please sign in to comment.