forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add helper for optimizing a QDQ format model for usage with ORT. (mic…
…rosoft#10595) * Add initial helper for optimizing a QDQ format model for usage with ORT. If a DQ node has multiple consumers it will end up in multiple QDQ node units. This is complicated to handle as each qdq unit could end up being handled by different execution providers. By duplicating the DQ node we simplify this logic. Generally the duplicate nodes will disappear when the qdq node unit is converted to a single node with a quantized operator. If there are qdq node units that are not able to be converted to use a quantized operator the ORT cleanup (pending) to drop remaining Q->DQ pairs between fp32 nodes can remove any remaining DQ nodes. * Fix pep8 warning Co-authored-by: Guoyu Wang <wanggy@outlook.com>
1 parent
4a79ed6
commit 4d3cd2f
Showing
6 changed files
with
164 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import argparse | ||
import onnx | ||
import os | ||
import pathlib | ||
|
||
from .qdq_model_utils import fix_dq_nodes_with_multiple_consumers | ||
|
||
|
||
def optimize_qdq_model(): | ||
parser = argparse.ArgumentParser(os.path.basename(__file__), | ||
description=''' | ||
Update a QDQ format ONNX model to ensure optimal performance when executed using | ||
ONNX Runtime. | ||
''') | ||
|
||
parser.add_argument('input_model', type=pathlib.Path, help='Provide path to ONNX model to update.') | ||
parser.add_argument('output_model', type=pathlib.Path, help='Provide path to write updated ONNX model to.') | ||
|
||
args = parser.parse_args() | ||
|
||
model = onnx.load(str(args.input_model.resolve(strict=True))) | ||
|
||
# there's just one utility to run currently but we expect that will grow | ||
fix_dq_nodes_with_multiple_consumers(model) | ||
|
||
onnx.save(model, str(args.output_model.resolve())) | ||
|
||
|
||
if __name__ == '__main__': | ||
optimize_qdq_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import onnx | ||
from ..onnx_model_utils import get_producer_consumer_maps, iterate_graph_per_graph_func | ||
|
||
|
||
def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs): | ||
updated_graphs = kwargs['updated_graphs'] | ||
node_to_consumers = kwargs['node_to_consumers'] | ||
validate_updates = kwargs['validate_updates'] | ||
|
||
nodes_to_update = [] | ||
for node in filter(lambda node: node.op_type == 'DequantizeLinear', graph.node): | ||
# node providing graph output won't have consumer nodes | ||
consumers = node_to_consumers[node] if node in node_to_consumers else [] | ||
if len(consumers) > 1: | ||
if not all(consumer in graph.node for consumer in consumers): | ||
# TODO: If this does ever occur, as long as it's only consumed in one subgraph we could leave that | ||
# value as is (no need to handle recursing into the subgraph) and update the consumers in this | ||
# graph only | ||
raise IndexError("DequantizeLinear node output is consumed by a subgraph. " | ||
"This is not currently supported.") | ||
|
||
nodes_to_update.append(node) | ||
|
||
if validate_updates: | ||
if nodes_to_update: | ||
# internal error. we somehow missed an update in the first pass when validate_upates was false | ||
raise ValueError('Graph still has DequantizeLinear nodes with multiple consumers.') | ||
|
||
return | ||
|
||
if nodes_to_update: | ||
dup_idx = 0 | ||
new_graph = onnx.GraphProto() | ||
graph_outputs = set([output.name for output in graph.output]) | ||
for node in graph.node: | ||
new_graph.node.append(node) | ||
if node in nodes_to_update: | ||
is_graph_output = node.output[0] in graph_outputs | ||
# create duplicate DQ nodes as needed so that there is one consumer per node. | ||
# this allows us to cleanly create a QDQ node group with no DQ nodes shared with other QDQ node groups. | ||
# if the node produces a graph output we need a duplicate DQ node for every consumer node. | ||
# if not, we can leave the first consumer as is and create duplicate nodes for the other consumers. | ||
start_idx = 0 if is_graph_output else 1 | ||
consumers = list(node_to_consumers[node])[start_idx:] | ||
|
||
for idx, consumer in enumerate(consumers): | ||
# create duplicate DQ node | ||
duplicate = onnx.NodeProto() | ||
duplicate.CopyFrom(node) | ||
# update node name for debugging. use the global dup idx for node duplication | ||
duplicate.name += f'/qdq_utils_dup_{dup_idx}' | ||
|
||
# update output. use the local idx for value duplication | ||
orig_output = node.output[0] | ||
new_output = f'{orig_output}/qdq_utils_dup_{idx}' | ||
duplicate.output[0] = new_output | ||
|
||
# update input on the consumer node. | ||
for input_idx, input_name in enumerate(consumer.input): | ||
if input_name == orig_output: | ||
consumer.input[input_idx] = new_output | ||
|
||
new_graph.node.append(duplicate) | ||
dup_idx += 1 | ||
|
||
# replace nodes | ||
del graph.node[:] | ||
graph.node.extend(new_graph.node) | ||
updated_graphs.append(graph) | ||
|
||
|
||
def fix_dq_nodes_with_multiple_consumers(model): | ||
''' | ||
Update a model if any DequantizeLinear nodes have multiple consumers. | ||
The QDQ node unit processing is overly complicated if this is the case, as the DQ node would be in multiple units, | ||
and the units may end up in different partitions at runtime. | ||
:param model: QDQ model to update | ||
''' | ||
node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph) | ||
|
||
updated_graphs = [] # list of GraphProto instances that were updated_graphs | ||
iterate_graph_per_graph_func(model.graph, _duplicate_dq_nodes_with_multiple_consumers, | ||
node_to_consumers=node_to_consumers, validate_updates=False, | ||
updated_graphs=updated_graphs) | ||
|
||
if updated_graphs: | ||
updated_graphs = [] | ||
node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph) | ||
iterate_graph_per_graph_func(model.graph, _duplicate_dq_nodes_with_multiple_consumers, | ||
node_to_consumers=node_to_consumers, validate_updates=True, | ||
updated_graphs=updated_graphs) | ||
|
||
# validate with check and by running shape inference. | ||
onnx.checker.check_model(model) | ||
_ = onnx.shape_inference.infer_shapes(model) |
Empty file.
32 changes: 32 additions & 0 deletions
32
tools/python/util/qdq_helpers/test/test_qdq_model_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import onnx | ||
import pathlib | ||
import unittest | ||
|
||
from ..qdq_model_utils import fix_dq_nodes_with_multiple_consumers | ||
|
||
script_dir = pathlib.Path(__file__).parent | ||
ort_root = script_dir.parents[4] | ||
|
||
# example usage from <ort root>/tools/python | ||
# python -m unittest util/qdq_helpers/test/test_qdq_model_utils.py | ||
# NOTE: at least on Windows you must use that as the working directory for all the imports to be happy | ||
|
||
|
||
class TestQDQUtils(unittest.TestCase): | ||
def test_fix_DQ_with_multiple_consumers(self): | ||
''' | ||
''' | ||
model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'qdq_with_multi_consumer_dq_nodes.onnx' | ||
model = onnx.load(str(model_path)) | ||
|
||
orig_dq_nodes = [n for n in model.graph.node if n.op_type == 'DequantizeLinear'] | ||
fix_dq_nodes_with_multiple_consumers(model) | ||
new_dq_nodes = [n for n in model.graph.node if n.op_type == 'DequantizeLinear'] | ||
|
||
# there are 3 DQ nodes with 2 consumers (an earlier Conv and later Add) | ||
# additionally the last one also provides a graph output | ||
# based on that there should be 3 new DQ nodes for the internal consumers and 1 new one for the graph output | ||
self.assertEqual(len(orig_dq_nodes) + 4, len(new_dq_nodes)) |