Update benchmark to reflect those used in our latest results (#3967)
Update optimizer for GPT2 models exported from PyTorch 1.5.
Update benchmark to use GPT2 models without past state inputs/outputs.
Update bert_perf_test to allow setting omp_num_threads etc., so that a single setting can be tested.
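
For orientation, a typical invocation of the updated benchmark might look like the following (a hypothetical sketch: -o/--optimize_onnx, --overwrite, and --verbose are defined in the diff below, while the other flag spellings are inferred from args.* usage in the script):

    python benchmark.py -o --validate_onnx --overwrite --verbose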
tianleiwu authored May 17, 2020
1 parent 56700be commit 07e9a4c
Showing 3 changed files with 270 additions and 166 deletions.
64 changes: 36 additions & 28 deletions onnxruntime/python/tools/transformers/Gpt2OnnxModel.py
@@ -15,7 +15,6 @@


class Gpt2OnnxModel(BertOnnxModel):

def __init__(self, model, num_heads, hidden_size):
super().__init__(model, num_heads, hidden_size)

@@ -28,7 +27,6 @@ def fuse_attention(self):
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()

nodes_to_remove = []
attention_count = 0

for normalize_node in self.get_nodes_by_op_type("LayerNormalization"):
@@ -65,10 +63,39 @@ def fuse_attention(self):
continue

qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0])
if qk_nodes is None:
logger.debug("fuse_attention: failed to match qk path")
continue
(softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
if qk_nodes is not None:
(softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
mask_nodes = self.match_parent_path(
sub_qk,
['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
[1, 0, 1, 0, 1, 0, 0, 0, 0, 0]) # yapf: disable
if mask_nodes is None:
logger.debug("fuse_attention: failed to match mask path")
continue
div_mask = mask_nodes[-1]

if div_qk != div_mask:
logger.debug("fuse_attention: skip since div_qk != div_mask")
continue
else:
# New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0])
if qk_nodes is None:
logger.debug("fuse_attention: failed to match qk path")
continue
(softmax_qk, where_qk, div_qk, matmul_qk) = qk_nodes
mask_nodes = self.match_parent_path(
where_qk,
['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
[ 0, 0, 0, 1, 0, 0, 0, 0, 0]) # yapf: disable
if mask_nodes is None:
logger.debug("fuse_attention: failed to match mask path")
continue
div_mask = mask_nodes[-1]

if div_qk != div_mask:
logger.debug("fuse_attention: skip since div_qk != div_mask")
continue
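
# Editorial note (inferred from usage above, not part of this commit):
# match_parent_path appears to walk upward from the given node, following the
# listed input indices and checking each parent's op type against the list;
# it returns the matched nodes in the listed order, or None on mismatch.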

q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
if q_nodes is None:
@@ -88,31 +115,12 @@ def fuse_attention(self):
logger.debug("fuse_attention: skip since split_v != split_k")
continue

mask_nodes = self.match_parent_path(
sub_qk,
['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
[1, 0, 1, 0, 1, 0, 0, 0, 0, 0]) # yapf: disable
if mask_nodes is None:
logger.debug("fuse_attention: failed to match mask path")
continue
(mul_mask, sub_mask, slice_mask, slice_mask_0, unsqueeze_mask, sub_mask, squeeze_mask, slice_mask_1,
shape_mask, div_mask) = mask_nodes

if div_qk != div_mask:
logger.debug("fuse_attention: skip since div_qk != div_mask")
continue

self.create_attention_node(gemm, gemm_qkv, layernorm_before_attention.output[0], reshape_qkv.output[0],
attention_count == 0)
nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
nodes_to_remove.extend(qk_nodes)
nodes_to_remove.extend(q_nodes)
nodes_to_remove.extend(k_nodes)
nodes_to_remove.extend(v_nodes)
nodes_to_remove.extend(mask_nodes)
# we rely on prune_graph() to clean old subgraph nodes:
# qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
attention_count += 1

self.remove_nodes(nodes_to_remove)
self.prune_graph()
logger.info(f"Fused Attention count:{attention_count}")

@@ -187,4 +195,4 @@ def postprocess(self):
reshape_count += 2

self.prune_graph()
logger.info(f"Remove Reshape count:{reshape_count}")
logger.info(f"postprocess: remove Reshape count:{reshape_count}")
117 changes: 68 additions & 49 deletions onnxruntime/python/tools/transformers/benchmark.py
@@ -13,11 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Benchmarking the inference of pretrained transformer models
""" Benchmarking the inference of pretrained transformer models.
PyTorch/TorchScript benchmark is based on https://github.com/huggingface/transformers/blob/master/examples/benchmarks.py.
One difference is that random input_ids are generated in this benchmark.
This will automatically export a pretrained model to ONNX, and do optimization (optional)
For onnxruntime, this script will convert a pretrained model to ONNX, and optimize it when the -o parameter is used.
Example commands:
Export all models to ONNX, optimize and validate them:
@@ -48,36 +48,49 @@

logger = logging.getLogger('')

DEFAULT_MODELS = ["bert-base-cased", "distilbert-base-uncased", "roberta-base", "gpt2"]

# List of pretrained models: https://huggingface.co/transformers/pretrained_models.html
# Pretrained model name to a tuple of input names, opset_version and optimization model type
# Pretrained model name to a tuple of input names, opset_version, use_external_data_format and optimization model type
MODELS = {
"bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 11, "bert"),
"distilbert-base-uncased": (["input_ids", "attention_mask"], 11, "bert"),
"roberta-base": (["input_ids", "attention_mask"], 11, "bert"),
"bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 11, False, "bert"),
"distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"),
"roberta-base": (["input_ids", "attention_mask"], 11, False, "bert"),

# No past state inputs for GPT models. You might set enable_cache=False in modeling_gpt2.py (need installing transformers from source) to disable past state output.
"gpt2": (["input_ids"], 11, "gpt2"),
"distilgpt2": (["input_ids"], 11, "gpt2"),
"openai-gpt": (["input_ids"], 11, "gpt2"),
# No past state inputs for GPT models.
"gpt2": (["input_ids"], 11, False, "gpt2"), # no past state inputs & outputs
"distilgpt2": (["input_ids"], 11, False, "gpt2"), # no past state inputs & outputs
"openai-gpt": (["input_ids"], 11, False, "gpt2"), # no past state inputs

# These models use Einsum, which needs opset version 12 and PyTorch 1.5.0 or above.
# Currently OnnxRuntime lacks a CUDA kernel for Einsum, so GPU inference will be very slow.
"albert-base-v2": (["input_ids"], 12, "bert"),
"xlnet-base-cased": (["input_ids"], 12, "bert"),
"albert-base-v2": (["input_ids"], 12, False, "bert"),
"xlnet-base-cased": (["input_ids"], 12, False, "bert"),

# This model is very large; it needs use_external_data_format=True to export.
"xlm-mlm-en-2048": (["input_ids"], 11, True, "bert"),
}
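
# Illustrative sketch (not in this commit): how one MODELS entry unpacks under
# the new four-field layout; the values below mirror the "gpt2" row above.
#
#   input_names, opset_version, use_external_data_format, model_type = MODELS["gpt2"]
#   assert input_names == ["input_ids"]
#   assert opset_version == 11 and use_external_data_format is False
#   assert model_type == "gpt2"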

cpu_count = psutil.cpu_count(logical=True)
# Set OMP environment variable before importing onnxruntime or torch.
if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = str(cpu_count)
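
# Editorial note: the OpenMP runtime reads OMP_NUM_THREADS when it is
# initialized, so the variable must be set before onnxruntime or torch
# (both of which can load OpenMP) is imported below.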

from transformers import (AutoConfig, AutoTokenizer, is_torch_available)
import torch
from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model)


# use_cache is True by default in GPT2Model. Here we wrap a class to disable past state output.
class GPT2ModelNoPastState(GPT2Model):
def __init__(self, config):
super().__init__(config)

def forward(self, input_ids):
return super().forward(input_ids, use_cache=False)


if is_torch_available():
import torch
from transformers import AutoModel
def load_pretrained_model(model_name, config, cache_dir):
if model_name in ["gpt2", "distilgpt2"]:
return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
return AutoModel.from_pretrained(model_name, config=config, cache_dir=cache_dir)
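
# Usage sketch (assumption, not part of this commit): with use_cache=False the
# wrapper returns only hidden states, so the exported ONNX graph carries no
# past-state inputs or outputs.
#
#   config = AutoConfig.from_pretrained("gpt2")
#   model = load_pretrained_model("gpt2", config=config, cache_dir="./cache_models")
#   input_ids = torch.randint(low=0, high=config.vocab_size, size=(1, 8), dtype=torch.long)
#   outputs = model(input_ids)  # no "present" past-state tensors in the result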


def create_onnxruntime_session(onnx_model_path, use_gpu):
@@ -176,16 +189,22 @@ def validate_onnx_model(onnx_model_filename, example_inputs, example_outputs_fla
optimize_model_statistics = {}


def optimize_onnx_model(onnx_model_filename, model_type, num_attention_heads, hidden_size, fp16):
def optimize_onnx_model(onnx_model_filename, model_type, num_attention_heads, hidden_size, fp16, overwrite):
optimized_model_filename = onnx_model_filename.replace(".onnx", "_fp16.onnx" if fp16 else "_fp32.onnx")
if not os.path.exists(optimized_model_filename):
if overwrite or not os.path.exists(optimized_model_filename):
from optimizer import optimize_model
from BertOnnxModel import BertOptimizationOptions
optimization_options = BertOptimizationOptions(model_type)
if fp16:
optimization_options.enable_gelu_approximation = True

# Use onnxruntime to optimize model, which will be saved to *_ort_cpu.onnx
opt_model = optimize_model(onnx_model_filename,
model_type,
num_heads=num_attention_heads,
hidden_size=hidden_size,
opt_level=99,
optimization_options=optimization_options,
only_onnxruntime=True)
optimize_model_statistics[onnx_model_filename] = opt_model.get_fused_operator_statistics()

@@ -194,7 +213,7 @@ def optimize_onnx_model(onnx_model_filename, model_type, num_attention_heads, hi
model_type,
num_heads=num_attention_heads,
hidden_size=hidden_size,
opt_level=0)
opt_level=99)
optimize_model_statistics[optimized_model_filename] = opt_model.get_fused_operator_statistics()

if fp16:
Expand All @@ -205,9 +224,9 @@ def optimize_onnx_model(onnx_model_filename, model_type, num_attention_heads, hi
return optimized_model_filename
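
# Usage sketch (illustrative values, not part of this commit; GPT-2 base uses
# 12 attention heads and a hidden size of 768, and "gpt2_1.onnx" follows the
# "{model_name}_{input_count}.onnx" naming scheme used below):
#
#   optimized = optimize_onnx_model("gpt2_1.onnx", "gpt2",
#                                   num_attention_heads=12, hidden_size=768,
#                                   fp16=False, overwrite=True)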


def export_onnx_model(model_name, cache_dir, input_names, fp16, optimize_onnx, validate_onnx):
def export_onnx_model(model_name, cache_dir, input_names, fp16, optimize_onnx, validate_onnx, overwrite):
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModel.from_pretrained(model_name, config=config, cache_dir=cache_dir)
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir)
model.cpu()

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
Expand All @@ -223,7 +242,7 @@ def export_onnx_model(model_name, cache_dir, input_names, fp16, optimize_onnx, v
example_outputs_flatten = update_flatten_list(example_outputs_flatten, [])

onnx_model_filename = "{}_{}.onnx".format(model_name, str(len(input_names)))
if not os.path.exists(onnx_model_filename):
if overwrite or not os.path.exists(onnx_model_filename):
logger.info("Exporting ONNX model to {}".format(onnx_model_filename))

dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
@@ -236,7 +255,8 @@ def export_onnx_model(model_name, cache_dir, input_names, fp16, optimize_onnx, v
example_outputs=example_outputs,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=MODELS[model_name][1])
opset_version=MODELS[model_name][1],
use_external_data_format=MODELS[model_name][2])
else:
logger.info(f"Skip export since model existed: {onnx_model_filename}")

@@ -245,9 +265,9 @@ def export_onnx_model(model_name, cache_dir, input_names, fp16, optimize_onnx, v
is_valid_onnx_model = validate_onnx_model(onnx_model_filename, example_inputs, example_outputs_flatten)

if optimize_onnx or fp16:
model_type = MODELS[model_name][2]
model_type = MODELS[model_name][3]
onnx_model_filename = optimize_onnx_model(onnx_model_filename, model_type, config.num_attention_heads,
config.hidden_size, fp16)
config.hidden_size, fp16, overwrite)

if validate_onnx:
is_valid_onnx_model = validate_onnx_model(onnx_model_filename, example_inputs, example_outputs_flatten)
@@ -272,7 +292,7 @@ def get_latency_result(runtimes, batch_size):


def run_onnxruntime(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, repeat_times, input_counts,
optimize_onnx, validate_onnx, cache_dir, verbose):
optimize_onnx, validate_onnx, cache_dir, verbose, overwrite):
import onnxruntime

results = []
@@ -295,7 +315,7 @@ def run_onnxruntime(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, r

with torch.no_grad():
onnx_model_file, is_valid_onnx_model, vocab_size, max_sequence_length = export_onnx_model(
model_name, cache_dir, input_names, fp16, optimize_onnx, validate_onnx)
model_name, cache_dir, input_names, fp16, optimize_onnx, validate_onnx, overwrite)
if not is_valid_onnx_model:
continue

@@ -347,19 +367,21 @@ def run_pytorch(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, repea

for model_name in model_names:
config = AutoConfig.from_pretrained(model_name, torchscript=torchscript, cache_dir=cache_dir)
model = AutoModel.from_pretrained(model_name, config=config, cache_dir=cache_dir)
model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
max_input_size = tokenizer.max_model_input_sizes[model_name]
logger.debug(f"Model {model}")
logger.debug(f"Number of parameters {model.num_parameters()}")

if fp16:
model.half()

device = torch.device("cuda:0" if use_gpu else "cpu")
model.to(device)

for batch_size in batch_sizes:
if batch_size <= 0:
continue
if fp16:
model.half()
device = torch.device("cuda:0" if use_gpu else "cpu")
model.to(device)

for sequence_length in sequence_lengths:
if max_input_size is not None and sequence_length > max_input_size:
@@ -372,13 +394,8 @@ def run_pytorch(use_gpu, model_names, fp16, batch_sizes, sequence_lengths, repea
dtype=torch.long,
device=device)
try:
if torchscript:
logger.debug("Tracing model with input shape {}".format(input_ids.shape))
inference = torch.jit.trace(model, input_ids)
inference(input_ids)
else:
inference = model
inference(input_ids)
inference = torch.jit.trace(model, input_ids) if torchscript else model
inference(input_ids)
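
# Editorial note: torch.jit.trace records a graph specialized to this input
# shape, and the call above doubles as a warm-up run before timeit measures below.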

runtimes = timeit.repeat(lambda: inference(input_ids), repeat=repeat_times, number=1)

@@ -452,11 +469,15 @@ def output_summary(results, csv_filename, args):


def output_fusion_statistics(optimize_model_statistics, csv_filename):
from transformers import __version__ as transformers_version
with open(csv_filename, mode="a", newline='') as csv_file:
column_names = ["model_filename"] + list(next(iter(optimize_model_statistics.values())).keys())
column_names = ["model_filename", "transformers", "torch"] + list(
next(iter(optimize_model_statistics.values())).keys())
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for key in optimize_model_statistics.keys():
optimize_model_statistics[key]["transformers"] = transformers_version
optimize_model_statistics[key]["torch"] = torch.__version__
optimize_model_statistics[key]["model_filename"] = key
csv_writer.writerow(optimize_model_statistics[key])
logger.info(f"Fusion statistics is saved to csv file: {csv_filename}")
@@ -496,6 +517,8 @@ def parse_arguments():

parser.add_argument("--verbose", required=False, action="store_true", help="Print more information")

parser.add_argument("--overwrite", required=False, action="store_true", help="Overwrite existing models")

parser.add_argument("-o",
"--optimize_onnx",
required=False,
@@ -543,7 +566,7 @@ def setup_logger(verbose):
coloredlogs.install(level='DEBUG', fmt='[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s')
else:
coloredlogs.install(fmt='%(message)s')
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.WARNING)


def main():
@@ -569,10 +592,6 @@ def main():

results = []
if enable_torch or enable_torchscript:
if not is_torch_available():
logger.error("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
return

if args.input_counts != [1]:
logger.warning("--input_counts is not implemented for torch or torchscript engine.")

@@ -588,7 +607,7 @@ def main():
try:
results += run_onnxruntime(args.use_gpu, args.models, args.fp16, args.batch_sizes, args.sequence_lengths,
args.test_times, args.input_counts, args.optimize_onnx, args.validate_onnx,
args.cache_dir, args.verbose)
args.cache_dir, args.verbose, args.overwrite)
except:
logger.error(f"Exception", exc_info=True)

