Model Fusion For Bart (microsoft#6105)
Fusion fix for Bart models
liuziyue authored Dec 15, 2020
1 parent 297c824 commit 980a93c
Showing 6 changed files with 66 additions and 25 deletions.
12 changes: 10 additions & 2 deletions onnxruntime/python/tools/transformers/benchmark.py
@@ -115,7 +115,7 @@ def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, b
continue

ort_output_names = [node_arg.name for node_arg in ort_session.get_outputs()]
output_buffers = {"last_state": None, "pooler": None}
output_buffers = []
device = "cuda" if use_gpu else "cpu"
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
max_last_state_size = numpy.prod(
@@ -150,16 +150,24 @@ def run_onnxruntime(use_gpu, model_names, model_class, precision, num_threads, b

logger.info("Run onnxruntime on {} with input shape {}".format(model_name,
[batch_size, sequence_length]))

if disable_ort_io_binding:
result = inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size)
else:
# Get output sizes from a dummy ort run
ort_outputs = ort_session.run(ort_output_names, ort_inputs)
output_buffer_max_sizes = [max_last_state_size]
for i in range(len(ort_outputs)):
if i == 2 and MODELS[model_name][3] == "gpt":
# past state output max size
output_buffer_max_sizes.append(max_pooler_size)
else:
output_buffer_max_sizes.append(max_last_state_size)

data_type = numpy.longlong if 'pt' in model_source else numpy.intc
result = inference_ort_with_io_binding(ort_session, ort_inputs, result_template, repeat_times,
ort_output_names, ort_outputs, output_buffers,
max_last_state_size, max_pooler_size, batch_size, device,
output_buffer_max_sizes, batch_size, device,
data_type)
logger.info(result)
results.append(result)
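For reference, a minimal sketch (not part of the commit) of how the reworked call site fits together; max_last_state_size and max_pooler_size are the upper bounds computed earlier in benchmark.py:

# Dummy run to discover the output count, then one max size per output so GPT
# past-state outputs get their own bound; all other outputs reuse the
# last-state bound. Condensed restatement of the hunk above.
ort_outputs = ort_session.run(ort_output_names, ort_inputs)
output_buffer_max_sizes = [
    max_pooler_size if (i == 2 and MODELS[model_name][3] == "gpt") else max_last_state_size
    for i in range(len(ort_outputs))
]
result = inference_ort_with_io_binding(ort_session, ort_inputs, result_template, repeat_times,
                                       ort_output_names, ort_outputs, output_buffers,
                                       output_buffer_max_sizes, batch_size, device, data_type)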
29 changes: 10 additions & 19 deletions onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -206,8 +206,7 @@ def inference_ort_with_io_binding(ort_session,
ort_output_names,
ort_outputs,
output_buffers,
max_last_state_size,
max_pooler_size,
output_buffer_max_sizes,
batch_size,
device,
data_type=numpy.longlong):
@@ -219,31 +218,23 @@
for name in ort_inputs.keys():
np_input = torch.from_numpy(ort_inputs[name]).to(device)
io_binding.bind_input(name, np_input.device.type, 0, data_type, np_input.shape, np_input.data_ptr())
has_pooler = True if len(ort_output_names) == 2 else False
# Bind outputs buffers with the sizes needed if not allocated already
if output_buffers["last_state"] is None:
allocateOutputBuffers(output_buffers, max_last_state_size, max_pooler_size, device, has_pooler)
last_state_buffer = output_buffers["last_state"]
pooler_buffer = output_buffers["pooler"]
io_binding.bind_output(ort_output_names[0], last_state_buffer.device.type, 0, numpy.float32, ort_outputs[0].shape,
last_state_buffer.data_ptr())
if has_pooler:
io_binding.bind_output(ort_output_names[1], pooler_buffer.device.type, 0, numpy.float32, ort_outputs[1].shape,
pooler_buffer.data_ptr())
if len(output_buffers) == 0:
allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)

for i in range(len(ort_output_names)):
io_binding.bind_output(ort_output_names[i], output_buffers[i].device.type, 0, numpy.float32, ort_outputs[i].shape,
output_buffers[i].data_ptr())
runtimes = timeit.repeat(lambda: ort_session.run_with_iobinding(io_binding), number=1, repeat=repeat_times)
result.update(result_template)
result.update({"io_binding": True})
result.update(get_latency_result(runtimes, batch_size))
return result


def allocateOutputBuffers(output_buffers, max_last_state_size, max_pooler_size, device, has_pooler=False):
def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device):
# Allocate output tensors with the largest test size needed, so the allocated memory
# can be reused for each test run.
# dummy last state
if output_buffers["last_state"] is None:
output_buffers["last_state"] = torch.empty(max_last_state_size, dtype=torch.float32, device=device)
# create dummy pooler
if output_buffers["pooler"] is None and has_pooler:
output_buffers["pooler"] = torch.empty(max_pooler_size, dtype=torch.float32, device=device)

for i in output_buffer_max_sizes:
output_buffers.append(torch.empty(i, dtype=torch.float32, device=device))
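
Outside the benchmark harness, the same I/O-binding pattern can be used on its own. A minimal sketch, assuming a CPU model at model.onnx with an input_ids input, a last_hidden_state output, and a hidden size of 768:

import numpy
import torch
import onnxruntime

# Bind a preallocated torch buffer as the output so timed runs skip allocation.
sess = onnxruntime.InferenceSession("model.onnx")
io_binding = sess.io_binding()

input_ids = torch.randint(0, 1000, (1, 128), dtype=torch.int64, device="cpu")
io_binding.bind_input("input_ids", input_ids.device.type, 0, numpy.longlong,
                      list(input_ids.shape), input_ids.data_ptr())

# Sized for the largest planned run and reused across runs, like output_buffers above.
output_buffer = torch.empty(1 * 128 * 768, dtype=torch.float32, device="cpu")
io_binding.bind_output("last_hidden_state", output_buffer.device.type, 0, numpy.float32,
                       [1, 128, 768], output_buffer.data_ptr())

sess.run_with_iobinding(io_binding)
last_hidden_state = output_buffer[:1 * 128 * 768].reshape(1, 128, 768)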
8 changes: 8 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_skiplayernorm.py
@@ -24,6 +24,9 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):

# In some models there is an input_ids -> Gather -> Add -> LayerNorm path, and one input of the
# Add node is an initializer with a fixed shape, which should not be fused into SkipLayerNorm
if add is None:
return

for add_input in add.input:
if self.model.get_initializer(add_input) != None:
return
@@ -32,6 +35,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
if len(self.model.get_parents(add)) != 2:
return

gather_path = self.model.match_parent_path(add, ['Gather'], [None])
if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
if self.model.match_parent_path(gather_path[0], ['ConstantOfShape'], [1]) is None:
return

if add is not None and add.op_type == 'Add' and self.model.is_safe_to_fuse_nodes(
[add, node], node.output, input_name_to_nodes, output_name_to_node):
self.nodes_to_remove.extend([add, node])
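The intent of the new guard, paraphrased as a sketch (helper names are from OnnxModel; the logic restates the added lines, inverted for readability):

gather_path = self.model.match_parent_path(add, ['Gather'], [None])  # any input of Add
if gather_path is not None:
    gather = gather_path[0]
    # Indices fed by a real graph input (e.g. input_ids) are fine, and so is the
    # ConstantOfShape-produced position-id pattern that Bart traces to. Anything
    # else means baked-in constant indices, so skip the fusion.
    indices_is_graph_input = self.model.find_graph_input(gather.input[1]) is not None
    from_constant_of_shape = self.model.match_parent_path(gather, ['ConstantOfShape'], [1]) is not None
    if not indices_is_graph_input and not from_constant_of_shape:
        return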
8 changes: 4 additions & 4 deletions onnxruntime/python/tools/transformers/huggingface_models.py
@@ -110,10 +110,10 @@
"flaubert/flaubert_base_cased": (["input_ids"], 11, False, "bert"),
"flaubert/flaubert_large_cased": (["input_ids"], 11, False, "bert"),
# Bart
#"facebook/bart-large": (["input_ids"], 11, False, "bert"),
#"facebook/bart-base": (["input_ids"], 11, False, "bert"),
#"facebook/bart-large-mnli": (["input_ids"], 11, False, "bert"),
#"facebook/bart-large-cnn": (["input_ids"], 11, False, "bert"),
"facebook/bart-large": (["input_ids"], 11, False, "bert"),
"facebook/bart-base": (["input_ids"], 11, False, "bert"),
"facebook/bart-large-mnli": (["input_ids"], 11, False, "bert"),
"facebook/bart-large-cnn": (["input_ids"], 11, False, "bert"),
#"facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"),
# DialoGPT
"microsoft/DialoGPT-small": (["input_ids"], 11, False, "gpt2"),
31 changes: 31 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_model_bert.py
@@ -176,7 +176,38 @@ def use_dynamic_axes(self, dynamic_batch_dim='batch_size', dynamic_seq_len='max_
dim_proto.dim_param = dynamic_batch_dim

def preprocess(self):
self.adjust_reshape_and_expand()
return

def adjust_reshape_and_expand(self):
nodes_to_remove = []
for node in self.nodes():
if node.op_type == 'Reshape':
# Clean up unnecessary Reshape nodes.
# Find Reshape nodes whose "shape" input carries no actual data and remove them.
reshape_shape = self.get_constant_value(node.input[1])
if reshape_shape is not None and reshape_shape.size == 0:
nodes_to_remove.extend([node])
self.replace_input_of_all_nodes(node.output[0], node.input[0])
continue

# Find the path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", and simplify the
# graph by changing the current Reshape's input to the output of the Slice.
reshape_path = self.match_parent_path(node, ['Expand', 'Expand', 'Reshape', 'Slice'], [0, 0, 0, 0],
self.output_name_to_node())
if reshape_path is not None:
expand_node = reshape_path[-3]
expand_shape_value = self.get_constant_value(expand_node.input[1])

reshape_before_expand = reshape_path[-2]
shape_value = self.get_constant_value(reshape_before_expand.input[1])

slice_node = reshape_path[-1]
if expand_shape_value is not None and shape_value is not None and len(expand_shape_value) == 2 and len(
shape_value) == 1 and expand_shape_value[1] == shape_value[0]:
node.input[0] = slice_node.output[0]
self.remove_nodes(nodes_to_remove)
logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")

def clean_graph(self):
output_name_to_node = self.output_name_to_node()
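A small numpy illustration of why the matched pattern is redundant (shapes are assumptions for illustration, not taken from the commit):

import numpy as np

batch, seq = 4, 8
sliced = np.arange(seq, dtype=np.int64).reshape(1, seq)  # Slice output: shape (1, seq)
flat = sliced.reshape(seq)                               # Reshape with shape_value = [seq]
expanded = np.broadcast_to(flat, (batch, seq))           # Expand with expand_shape_value = [batch, seq]

# Every row of `expanded` equals `sliced[0]`, so a downstream Reshape sees the
# same per-position data when rewired to the Slice output directly.
assert (expanded == sliced).all()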
3 changes: 3 additions & 0 deletions onnxruntime/python/tools/transformers/test_optimizer.py
@@ -383,6 +383,9 @@ def test_huggingface_flaubert_fusion(self):
def test_huggingface_dialogpt_fusion(self):
self._test_optimizer_on_huggingface_model("microsoft/DialoGPT-small", [0, 12, 0, 12, 0, 25, 0])

def test_huggingface_bart_fusion(self):
self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])

def test_bert_base_cased_from_tf(self):
self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], 1)
self._test_optimizer_on_tf_model("bert-base-cased", [1, 12, 0, 0, 12, 0, 24], 2)
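To exercise the same path by hand, an optimize-and-inspect sketch (the ONNX file path is an assumption; facebook/bart-base uses 12 attention heads and hidden size 768):

from onnxruntime.transformers import optimizer

# Apply the BERT-style fusion pass that test_huggingface_bart_fusion exercises
# to an already-exported facebook/bart-base ONNX file.
opt_model = optimizer.optimize_model("bart-base.onnx", model_type="bert",
                                     num_heads=12, hidden_size=768)
opt_model.save_model_to_file("bart-base_opt.onnx")

# Per-operator fusion counts, e.g. Attention and SkipLayerNormalization.
print(opt_model.get_fused_operator_statistics())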
