forked from microsoft/onnxruntime
Commit
add tool for generating test data for longformer (microsoft#6415)
Showing 1 changed file with 241 additions and 0 deletions.
onnxruntime/python/tools/transformers/longformer/generate_test_data.py
@@ -0,0 +1,241 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------

# Generate test data for a longformer model, so that we can use onnxruntime_perf_test.exe to evaluate the inference latency.
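#
# A minimal usage sketch (the model path is illustrative; the flags are defined in
# parse_arguments below):
#
#   python generate_test_data.py --model longformer.onnx --batch_size 1 \
#       --sequence_length 128 --global_tokens 10
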
import sys
import argparse
import numpy as np
import os
import random
from pathlib import Path
from onnx import ModelProto, TensorProto, numpy_helper

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from onnx_model import OnnxModel
from bert_test_data import fake_input_ids_data, fake_input_mask_data, output_test_data


def parse_arguments():
    parser = argparse.ArgumentParser()

    parser.add_argument('--model', required=True, type=str, help="longformer onnx model path.")

    parser.add_argument('--output_dir',
                        required=False,
                        type=str,
                        default=None,
                        help="output test data path. If not specified, test data is saved under the model directory.")

    parser.add_argument('--batch_size', required=False, type=int, default=1, help="batch size of input")

    parser.add_argument('--sequence_length',
                        required=False,
                        type=int,
                        default=128,
                        help="maximum sequence length of input")

    parser.add_argument('--global_tokens', required=False, type=int, default=10, help="number of global tokens")

    parser.add_argument('--input_ids_name', required=False, type=str, default=None, help="input name for input ids")

    parser.add_argument('--input_mask_name',
                        required=False,
                        type=str,
                        default=None,
                        help="input name for attention mask")

    parser.add_argument('--global_mask_name',
                        required=False,
                        type=str,
                        default=None,
                        help="input name for global attention mask")

    parser.add_argument('--samples', required=False, type=int, default=1, help="number of test cases to be generated")

    parser.add_argument('--seed', required=False, type=int, default=3, help="random seed")

    parser.add_argument('--verbose', required=False, action='store_true', help="print verbose information")
    parser.set_defaults(verbose=False)

    args = parser.parse_args()
    return args


def get_longformer_inputs(onnx_file, input_ids_name=None, input_mask_name=None, global_mask_name=None):
    """
    Get graph inputs for longformer model.
    """
    model = ModelProto()
    with open(onnx_file, "rb") as f:
        model.ParseFromString(f.read())

    onnx_model = OnnxModel(model)
    graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()

    if input_ids_name is not None:
        input_ids = onnx_model.find_graph_input(input_ids_name)
        if input_ids is None:
            raise ValueError(f"Graph does not have input named {input_ids_name}")

        input_mask = None
        if input_mask_name:
            input_mask = onnx_model.find_graph_input(input_mask_name)
            if input_mask is None:
                raise ValueError(f"Graph does not have input named {input_mask_name}")

        global_mask = None
        if global_mask_name:
            global_mask = onnx_model.find_graph_input(global_mask_name)
            if global_mask is None:
                raise ValueError(f"Graph does not have input named {global_mask_name}")

        expected_inputs = 1 + (1 if input_mask else 0) + (1 if global_mask else 0)
        if len(graph_inputs) != expected_inputs:
            raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")

        return input_ids, input_mask, global_mask

    if len(graph_inputs) != 3:
        raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")

    # Try to guess the inputs based on their names.
    input_ids = None
    input_mask = None
    global_mask = None
    for graph_input in graph_inputs:
        input_name_lower = graph_input.name.lower()
        if "global" in input_name_lower:
            global_mask = graph_input
        elif "mask" in input_name_lower:
            input_mask = graph_input
        else:
            input_ids = graph_input

    if input_ids and input_mask and global_mask:
        return input_ids, input_mask, global_mask

    raise ValueError("Failed to assign all 3 inputs. You might try renaming the graph inputs.")


def fake_global_mask_data(global_mask, batch_size, sequence_length, num_global_tokens):
    """
    Fake data based on the graph input of the global attention mask.
    Args:
        global_mask (TensorProto): graph input of the global attention mask tensor.
    Returns:
        data (np.array): the data for input tensor
    """
    data_type = global_mask.type.tensor_type.elem_type
    assert data_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]

    data = np.zeros((batch_size, sequence_length), dtype=np.int32)
    if num_global_tokens > 0:
        assert num_global_tokens <= sequence_length
        # Mark the first num_global_tokens positions of each sequence as global attention tokens.
        data[:, :num_global_tokens] = 1

    if data_type == TensorProto.FLOAT:
        data = np.float32(data)
    elif data_type == TensorProto.INT64:
        data = np.int64(data)

    return data
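# For example (illustrative values): fake_global_mask_data(global_mask, batch_size=1,
# sequence_length=8, num_global_tokens=2) yields [[1, 1, 0, 0, 0, 0, 0, 0]].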


def fake_test_data(batch_size,
                   sequence_length,
                   test_cases,
                   dictionary_size,
                   verbose,
                   random_seed,
                   input_ids,
                   input_mask,
                   global_mask,
                   num_global_tokens,
                   random_mask_length=False):
    """
    Generate fake input data for test.
    """
    assert input_ids is not None

    np.random.seed(random_seed)
    random.seed(random_seed)

    all_inputs = []
    for _ in range(test_cases):
        input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
        inputs = {input_ids.name: input_1}

        if input_mask:
            inputs[input_mask.name] = fake_input_mask_data(input_mask, batch_size, sequence_length,
                                                           random_mask_length)

        if global_mask:
            inputs[global_mask.name] = fake_global_mask_data(global_mask, batch_size, sequence_length,
                                                             num_global_tokens)

        if verbose and len(all_inputs) == 0:
            print("Example inputs", inputs)
        all_inputs.append(inputs)

    return all_inputs
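# The list returned above holds one dict per test case; each dict maps a graph input
# name to a numpy array of shape (batch_size, sequence_length).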


def generate_test_data(batch_size,
                       sequence_length,
                       test_cases,
                       seed,
                       verbose,
                       input_ids,
                       input_mask,
                       global_mask,
                       num_global_tokens,
                       random_mask_length=False):
    dictionary_size = 10000
    all_inputs = fake_test_data(batch_size, sequence_length, test_cases, dictionary_size, verbose, seed, input_ids,
                                input_mask, global_mask, num_global_tokens, random_mask_length)
    if len(all_inputs) != test_cases:
        print("Failed to create test data.")
    return all_inputs


def create_longformer_test_data(model, output_dir, batch_size, sequence_length, test_cases, seed, verbose,
                                input_ids_name, input_mask_name, global_mask_name, num_global_tokens):
    input_ids, input_mask, global_mask = get_longformer_inputs(model, input_ids_name, input_mask_name,
                                                               global_mask_name)
    all_inputs = generate_test_data(batch_size, sequence_length, test_cases, seed, verbose, input_ids, input_mask,
                                    global_mask, num_global_tokens)

    for i, inputs in enumerate(all_inputs):
        output_test_data(output_dir, i, inputs)
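# A programmatic usage sketch (illustrative values; they mirror the defaults in main below):
#   create_longformer_test_data("longformer.onnx", "./b1_s128_g10", batch_size=1,
#                               sequence_length=128, test_cases=1, seed=3, verbose=False,
#                               input_ids_name=None, input_mask_name=None,
#                               global_mask_name=None, num_global_tokens=10)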


def main():
    args = parse_arguments()

    output_dir = args.output_dir
    if output_dir is None:
        # Default output directory is a sub-directory under the directory of the model.
        output_dir = os.path.join(
            Path(args.model).parent, "b{}_s{}_g{}".format(args.batch_size, args.sequence_length, args.global_tokens))

    # Create the output directory if it does not exist. Existing test data files will be overwritten.
    path = Path(output_dir)
    path.mkdir(parents=True, exist_ok=True)

    create_longformer_test_data(args.model, output_dir, args.batch_size, args.sequence_length, args.samples,
                                args.seed, args.verbose, args.input_ids_name, args.input_mask_name,
                                args.global_mask_name, args.global_tokens)

    print("Test data is saved to directory:", output_dir)


if __name__ == "__main__":
    main()
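# Note on consuming the output: output_test_data (from bert_test_data) writes each sample
# as a test_data_set_N directory of input_K.pb tensors, the standard ONNX test data layout
# that onnxruntime_perf_test can read when the directories sit next to the model. The perf
# tool's flags vary by version, so check onnxruntime_perf_test --help for current options.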