optimize a bert model converted using tf2onnx (microsoft#5492)
* optimize a bert model converted using tf2onnx

* add test data

* update

* remove comments

* format

* Revert "format"

This reverts commit f8ae88c.

* Revert "remove comments"

This reverts commit 59d8a69.

* add a squeeze node to convert a 3-d mask to 2-d

* update

* update

* verify and add comments
wangyems authored Dec 1, 2020
1 parent 3323fb6 commit 5f51689
Showing 5 changed files with 190 additions and 27 deletions.
11 changes: 7 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -131,15 +131,18 @@ def create_attention_node(self, mask_index, q_matmul, k_matmul, v_matmul, q_add,
         weight = helper.make_tensor(name=attention_node_name + '_qkv_weight',
                                     data_type=TensorProto.FLOAT,
                                     dims=[self.hidden_size, 3 * self.hidden_size],
-                                    vals=bytes(qkv_weight.flatten()),
-                                    raw=True)
+                                    vals=qkv_weight.flatten().tolist())
+        # Sometimes weights and bias are stored in fp16
+        if q_weight.data_type == 10:
+            weight.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(weight).astype(np.float16), weight.name))
         self.model.add_initializer(weight)
 
         bias = helper.make_tensor(name=attention_node_name + '_qkv_bias',
                                   data_type=TensorProto.FLOAT,
                                   dims=[3 * self.hidden_size],
-                                  vals=bytes(qkv_bias.flatten()),
-                                  raw=True)
+                                  vals=qkv_bias.flatten().tolist())
+        if q_bias.data_type == 10:
+            bias.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(bias).astype(np.float16), bias.name))
         self.model.add_initializer(bias)
 
         attnetion_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias']
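As context for this hunk, a minimal standalone sketch of the fp16 round-trip it introduces; the tensor name and sizes here are illustrative, not taken from the patch:

import numpy as np
from onnx import TensorProto, helper, numpy_helper

# Illustrative stand-ins for the merged Q/K/V weights (not values from the patch).
hidden_size = 4
qkv_weight = np.random.rand(hidden_size, 3 * hidden_size).astype(np.float32)

# Build the initializer from a plain Python list, as the new code does.
weight = helper.make_tensor(name='attention_qkv_weight',
                            data_type=TensorProto.FLOAT,
                            dims=[hidden_size, 3 * hidden_size],
                            vals=qkv_weight.flatten().tolist())

# data_type == 10 is TensorProto.FLOAT16: when the source weights were fp16,
# round-trip through numpy to rewrite the initializer in fp16, keeping its name.
weight.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(weight).astype(np.float16), weight.name))
assert weight.data_type == TensorProto.FLOAT16
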
64 changes: 44 additions & 20 deletions onnxruntime/python/tools/transformers/fusion_layernorm.py
@@ -121,24 +121,28 @@ def __init__(self, model: OnnxModel):
 
     def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         """
-         Layer Norm from Keras in Tensorflow:
-          +----------------------+
-          |                      |
-          |                      v                            (B)                          (B)            (A)
-         Add --> ReduceMean --> Sub --> Mul --> ReduceMean --> Add --> Sqrt --> Reciprocal --> Mul --> Mul --> Sub --> Add
-          |      |              |                                                               ^              ^
-          |      |              |                                                               |              |
-          |      +-------------------------------------------------------------------------------------+      |
-          |                     v                                                                       |
-          +------------------------------------------------------------------------------------------> Mul----+
+         Layer Norm from Tensorflow model (using keras2onnx or tf2onnx):
+          +------------------------------------+
+          |                                    |
+          |                                    |
+       (Cast_1)                                |
+          |                                    |
+          |                                    v                                        (B)                           (B)            (A)
+         Add --> (Cast_1) --> ReduceMean --> Sub --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocal --> Mul --> Mul --> Sub --> Add
+          |                                    |                                                                         ^              ^
+          |                                    |                                                                         |              |
+          |                                    +---------------------------(Cast_2)--------------------------------------|------+       |
+          |                                    v                                                                                |       |
+          +--------------------------------------------------------------------------------------------------------------> Mul-+-------+
         """
-        return_indice = []
-        parent_nodes = self.model.match_parent_path(
+        _, parent_nodes, return_indice = self.model.match_parent_paths(
             node,
-            ['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
-            [   1,     1,  None,            0,      0,     0,         None,     0,     0,         None],
-            output_name_to_node,
-            return_indice=return_indice)  # yapf: disable
+            [(['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
+              [    1,     1,  None,            0,      0,     0,         None,     0,     0,         None]),
+             (['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'Cast', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
+              [    1,     1,  None,            0,      0,     0,      0,         None,     0,     0,         None])],
+            output_name_to_node)  # yapf: disable
 
         if parent_nodes is None:
             return
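The switch from match_parent_path to match_parent_paths lets the fusion accept either the plain fp32 pattern or the Cast-instrumented one produced for fp16 models. A simplified sketch of the helper's semantics, built only on the single-path matcher visible in the removed lines (this is an illustration, not the actual onnxruntime implementation):

def match_parent_paths_sketch(model, node, candidate_paths, output_name_to_node):
    # Try each (op types, input indices) candidate in order; return the index
    # of the first path that matches, its matched nodes, and the input indices
    # that were resolved for the wildcard (None) entries.
    for i, (op_types, input_indices) in enumerate(candidate_paths):
        return_indice = []
        matched = model.match_parent_path(node, op_types, input_indices,
                                          output_name_to_node,
                                          return_indice=return_indice)
        if matched is not None:
            return i, matched, return_indice
    return -1, None, None
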
@@ -148,24 +152,35 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
             logger.debug(f"return_indice is expected in [0, 1], but got {return_indice}")
             return
 
-        sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes
+        sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0 = parent_nodes[:6]
+        reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]
+
+        cast_node_3 = None
+        if len(parent_nodes) == 11:
+            cast_node_3 = parent_nodes[6]
+            assert(cast_node_3.op_type == 'Cast')
 
         mul_node_3 = self.model.match_parent(node, 'Mul', 0, output_name_to_node)
         if mul_node_3 is None:
             logger.debug("mul_node_3 not found")
             return
 
-        root_node = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
+        node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
+        root_node = node_before_reduce if cast_node_3 is None else self.model.get_parent(node_before_reduce, 0, output_name_to_node)
         if root_node is None:
             logger.debug("root node is none")
             return
 
         i, epsilon = self.model.get_constant_input(add_node_0)
-        if epsilon is None or epsilon <= 0 or epsilon > 1.0E-5:
+        if epsilon is None or epsilon <= 0 or (epsilon > 1.0E-5 and cast_node_3 is None):
             logger.debug("epsilon is not matched")
             return
 
-        if reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input:
+        if cast_node_3 is None and (reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input):
             logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
             return
+
+        if cast_node_3 is not None and (node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input):
+            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
+            return

@@ -177,6 +192,14 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
             node, sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0,
             mul_node_2, sub_node_1, reduce_mean_node_1, mul_node_3
         ]
+
+        if cast_node_3 is not None:
+            cast_node_2 = self.model.match_parent(mul_node_0, 'Cast', 0, output_name_to_node)
+            if cast_node_2 is None:
+                logger.debug("cast_node_2 not found")
+                return
+            subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])
+
         if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.model.input_name_to_nodes(),
                                                 self.model.output_name_to_node()):
             logger.debug("not safe to fuse layer normalization")
@@ -189,7 +212,8 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
 
         #TODO: add epsilon attribute
         fused_node = helper.make_node('LayerNormalization',
-                                      inputs=[reduce_mean_node_1.input[0], weight_input, bias_input],
+                                      inputs=[mul_node_3.input[0], weight_input, bias_input],
                                       outputs=[node.output[0]])
+        fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
         self.nodes_to_add.append(fused_node)

125 changes: 123 additions & 2 deletions onnxruntime/python/tools/transformers/onnx_model_bert_tf.py
@@ -9,7 +9,7 @@
 import argparse
 import numpy as np
 from collections import deque
-from onnx import ModelProto, TensorProto, numpy_helper
+from onnx import ModelProto, TensorProto, numpy_helper, helper
 from onnx_model_bert import BertOnnxModel
 
 logger = logging.getLogger(__name__)
@@ -295,6 +295,126 @@ def process_embedding(self):
                 self.prune_graph()
                 break
 
+    def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
+        for x in [matmul_q, matmul_k, matmul_v]:
+            root_input = x.input[0]
+            root_node = output_name_to_node[root_input]
+            if root_node == parent:
+                continue
+            logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}")
+            return False
+
+        return True
+
+    def fuse_attention(self):
+        output_name_to_node = self.output_name_to_node()
+
+        nodes_to_remove = []
+        attention_count = 0
+
+        skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
+        for normalize_node in skip_layer_norm_nodes:
+            # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+            parent = self.get_parent(normalize_node, 1)
+            if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]:
+                parent = self.get_parent(normalize_node, 0)
+                if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]:
+                    logger.debug("Failed to match parent of normalize_node")
+                    continue
+
+            qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'],
+                                               [0, 0, 0, 0, 0])
+            if qkv_nodes is None:
+                qkv_nodes = self.match_parent_path(normalize_node, ['MatMul', 'Reshape', 'Transpose', 'MatMul'],
+                                                   [1, 0, 0, 0])
+                if qkv_nodes is None:
+                    logger.debug("Failed to match qkv nodes")
+                    continue
+
+            (reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes[-3:]
+            v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
+            if v_nodes is None:
+                logger.debug("Failed to match v path")
+                continue
+
+            (transpose_v, reshape_v, add_v, matmul_v) = v_nodes
+            qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', "Mul", 'MatMul'], [0, 0, 0, 0])
+            if qk_nodes is None:
+                logger.debug("Failed to match qk_paths")
+                continue
+            (softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes
+
+            q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0])
+            if q_nodes is None:
+                logger.debug("Failed to match q path")
+                continue
+            (transpose_q, reshape_q, add_q, matmul_q) = q_nodes
+
+            k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0])
+            if k_nodes is None:
+                logger.debug("Failed to match k path")
+                continue
+            (transpose_k, reshape_k, add_k, matmul_k) = k_nodes
+
+            mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Unsqueeze'], [1, 0, 1])
+            if mask_nodes is None:
+                mask_nodes = self.match_parent_path(add_qk, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Mul'], [1, 0, 1, 0, 0])
+                if mask_nodes is None:
+                    logger.debug("Failed to match mask path")
+                    continue
+
+            if not self.has_constant_input(mask_nodes[1], 1):
+                logger.debug("Sub node expected to have an input with constant value 1.0.")
+                continue
+
+            # add a squeeze node to convert a 3-d mask to 2-d
+            squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0])
+            squeeze_node_name = "Squeeze_3d_to_2d_mask"
+            squeeze_output_name = squeeze_node_name + "_output"
+            if squeeze_node is None and len(mask_nodes) == 5:
+                mask_input = mask_nodes[-1].input[1]
+                self.add_node(
+                    helper.make_node("Squeeze", [mask_input], [squeeze_output_name], squeeze_node_name, axes=[1]))
+                mask_nodes[-1].input[0] = squeeze_output_name
+
+            is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node)
+            if is_same_root:
+                mask_index = self.attention_mask.process_mask(squeeze_output_name)
+                logger.debug("Create an Attention node.")
+                attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v,
+                                                                             add_q, add_k, add_v, parent.output[0],
+                                                                             reshape_qkv.output[0])
+                if parent.op_type == 'Reshape':
+                    # Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input
+                    hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1]
+                    tensor = helper.make_tensor(
+                        name=parent.name + "_modified",
+                        data_type=TensorProto.INT64,
+                        dims=[3],
+                        vals=np.int64([[1, -1, hidden_size]]).tobytes(),
+                        raw=True)
+                    self.add_initializer(tensor)
+                    parent.input[1] = parent.name + "_modified"
+
+                if attention_node is None:
+                    continue
+
+                self.add_node(attention_node)
+                attention_count += 1
+
+                nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
+                nodes_to_remove.extend(qk_nodes)
+                nodes_to_remove.extend(q_nodes)
+                nodes_to_remove.extend(k_nodes)
+                nodes_to_remove.extend(v_nodes)
+                nodes_to_remove.extend(mask_nodes)
+            else:
+                logger.debug("Root node not matched.")
+                continue
+        self.remove_nodes(nodes_to_remove)
+        self.update_graph()
+        logger.info(f"Fused Attention count:{attention_count}")
+
     def preprocess(self):
         self.remove_identity()
         self.process_embedding()
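The commit's headline change, converting a 3-D attention mask to 2-D, amounts to inserting one Squeeze node ahead of the mask processing. A self-contained sketch with illustrative tensor names (Squeeze still took axes as an attribute in the opsets this tool targeted at the time):

import numpy as np
from onnx import helper

# Squeeze a (batch, 1, sequence) attention mask down to (batch, sequence) so
# the downstream mask processing sees the 2-D layout it expects.
squeeze_node = helper.make_node('Squeeze',
                                inputs=['attention_mask_3d'],
                                outputs=['attention_mask_2d'],
                                name='Squeeze_3d_to_2d_mask',
                                axes=[1])

# Shape-level effect of the node above:
mask_3d = np.ones((1, 1, 128), dtype=np.int64)
assert np.squeeze(mask_3d, axis=1).shape == (1, 128)
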
@@ -315,4 +435,5 @@ def remove_reshape_before_first_attention(self):
 
     def postprocess(self):
         self.remove_reshape_before_first_attention()
-        self.prune_graph()
+        # Temporary work around for the following comment as it will cause topological issues for a bert model
+        # self.prune_graph()
Binary file not shown (the new test model added by this commit).
17 changes: 16 additions & 1 deletion onnxruntime/python/tools/transformers/test_optimizer.py
@@ -30,6 +30,7 @@
     "gpt2_past": ('gpt2_pytorch1.5_opset11', 'gpt2_past.onnx'),
     "gpt2_past_mask": ('FUSION', 'gpt2_past_mask_one_layer.onnx'),
     "multiple_embed": ('FUSION', 'embed_layer_norm_multiple.onnx'),
+    "bert_tf2onnx_0": ('other_models', 'bert_tf2onnx_0.onnx')
 }
 
 skip_on_ort_version = pytest.mark.skipif(onnxruntime.__version__ == ('1.3.0'),
@@ -297,6 +298,20 @@ def test_multiple_embed(self):
         }
         self.verify_node_count(model, expected_node_count, 'test_multiple_embed')
 
+    def test_bert_tf2onnx_0(self):
+        input = _get_test_model_path('bert_tf2onnx_0')
+        model = optimize_model(input, 'bert_tf', num_heads=2, hidden_size=8)
+        expected_node_count = {
+            'EmbedLayerNormalization': 0,
+            'Attention': 6,
+            'Gelu': 0,
+            'FastGelu': 6,
+            'BiasGelu': 0,
+            'LayerNormalization': 0,
+            'SkipLayerNormalization': 13
+        }
+        self.verify_node_count(model, expected_node_count, 'test_bert_tf2onnx_0')
+
     def test_huggingface_bert_fusion(self):
         self.test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1)
         self.test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2)
@@ -343,4 +358,4 @@ def test_huggingface_dialogpt_fusion(self):
 
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
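The new test drives the same entry point a user would. A sketch of that flow, assuming the packaged module path onnxruntime.transformers.optimizer; the file names are placeholders, and num_heads=2 / hidden_size=8 mirror the tiny test model rather than a real BERT:

from onnxruntime.transformers.optimizer import optimize_model

# num_heads / hidden_size must match the converted model.
model = optimize_model('bert_tf2onnx_0.onnx', model_type='bert_tf', num_heads=2, hidden_size=8)
model.save_model_to_file('bert_tf2onnx_0_opt.onnx')
print(model.get_fused_operator_statistics())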
