Skip to content

Commit

Permalink
Support required types when excluding typed registrations (microsoft#…
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 authored Mar 8, 2021
1 parent de6e66f commit b6c4a7a
Showing 8 changed files with 273 additions and 178 deletions.
12 changes: 6 additions & 6 deletions docs/Reduced_Operator_Kernel_build.md
Original file line number Diff line number Diff line change
@@ -2,13 +2,13 @@

In order to reduce the compiled binary size of ONNX Runtime (ORT), the operator kernels included in the build can be reduced to just the kernels required by your model(s).

A configuration file must be created with details of the kernels that are required.
A configuration file must be created with details of the kernels that are required.

Following that, ORT must be manually built, providing the configuration file in the `--include_ops_by_config` parameter. The build process will update the ORT kernel registration source files to exclude the unused kernels.
Following that, ORT must be manually built, providing the configuration file in the `--include_ops_by_config` parameter. The build process will update the ORT kernel registration source files to exclude the unused kernels.

See the [build instructions](https://www.onnxruntime.ai/docs/how-to/build.html#build-instructions) for more details on building ORT.

When building ORT with a reduced set of kernel registrations, `--skip_tests` **MUST** be specified as the kernel reduction will render many of the unit tests invalid.
When building ORT with a reduced set of kernel registrations, `--skip_tests` **MUST** be specified as the kernel reduction will render many of the unit tests invalid.

NOTE: The operator exclusion logic when building with an operator reduction configuration file will only disable kernel registrations each time it runs. It will NOT re-enable previously disabled kernels. If you wish to change the list of kernels included, it is best to revert the repository to a clean state (e.g. via `git reset --hard`) before building ORT again.

@@ -75,7 +75,7 @@ If, for example, the types of inputs 0 and 1 were important, the entry may look
`{"inputs": {"0": ["float", "int32_t"], "1": ["int32_t"]}}`

Finally, some operators do non-standard things and store their type information under a 'custom' key.
ai.onnx.OneHot is an example of this, where 3 type names from the inputs are combined into a string.
`{"custom": ["float_int64_t_int64_t", "int64_t_string_int64_t"]}`
ai.onnx.OneHot is an example of this, where the three input types are combined into a triple.
`{"custom": [["float", "int64_t", "int64_t"], ["int64_t", "std::string", "int64_t"]]}`

For these reasons, it is best to generate the configuration file first, and manually edit any entries if needed.
For these reasons, it is best to generate the configuration file first, and manually edit any entries if needed.
Original file line number Diff line number Diff line change
@@ -130,7 +130,9 @@ jobs:
displayName: Build minimal onnxruntime [exceptions ENABLED, type reduction ENABLED (globally allowed types)] and run tests
inputs:
script: |
echo "!globally_allowed_types;bool,float,int8_t,uint8_t" \
printf "%s\n%s\n" \
"!globally_allowed_types;bool,float,int8_t,uint8_t" \
"!no_ops_specified_means_all_ops_are_required" \
> $(test_data_directory)/globally_allowed_types.config && \
docker run --rm \
--volume $(Build.SourcesDirectory):/onnxruntime_src \
5 changes: 5 additions & 0 deletions tools/ci_build/op_registration_utils.py
Original file line number Diff line number Diff line change
@@ -184,6 +184,11 @@ def _process_lines(lines: typing.List[str], offset: int, registration_processor:
registration_processor.process_registration(lines_to_process, domain, op_type,
int(start_version), int(end_version), type)

else:
log.warning("Ignoring unhandled kernel registration variant: {}".format(code_line))
for line in lines_to_process:
registration_processor.process_other_line(line)

return offset + 1


103 changes: 17 additions & 86 deletions tools/ci_build/reduce_op_kernels.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@
import argparse
import op_registration_utils
import os
import re
import shutil
import sys
import typing
@@ -19,74 +18,24 @@
sys.path.append(ort_tools_py_path)

from util import parse_config # noqa
from util.ort_format_model.operator_type_usage_processors import OperatorTypeUsageManager # noqa
from util.ort_format_model.operator_type_usage_processors import OpTypeImplFilterInterface # noqa

log = get_logger("reduce_op_kernels")


# valid C++ scalar types that can be specified as globally allowed types
_valid_allowed_types = {
"bool",
"uint8_t", "uint16_t", "uint32_t", "uint64_t",
"int8_t", "int16_t", "int32_t", "int64_t",
"MLFloat16", "BFloat16", # in onnxruntime namespace
"float", "double",
"string", # in std namespace
}


def _validated_globally_allowed_types(globally_allowed_types: typing.Collection[str]) -> typing.Set[str]:
'''Return a valid set of globally allowed types.'''
# ensure globally_allowed_types is a set
if not isinstance(globally_allowed_types, set):
globally_allowed_types = set(globally_allowed_types)

if not globally_allowed_types <= _valid_allowed_types:
raise ValueError(
"Globally allowed types must be a subset of valid allowed types. Actual: {}, valid: {}".format(
globally_allowed_types, sorted(_valid_allowed_types)))

return globally_allowed_types


def _type_re_from_globally_allowed_types(globally_allowed_types: typing.Set[str]) -> typing.re.Pattern:
'''Return a regular expression to match type registration strings to a set of globally allowed types.'''
# to keep a registration, the type should match patterns like:
# 1. T0
# 2. T0_T1_T2
# where Ti is a member of globally_allowed_types and multiple Ti's are delimited by "_"
# this covers both the common case (1) and special cases like OneHot registration (2)
allowed_type_subpattern = \
"(?:" + "|".join(re.escape(allowed_type) for allowed_type in sorted(globally_allowed_types)) + ")"
return re.compile("^{0}(?:_{0})*$".format(allowed_type_subpattern))


class _ExcludingRegistrationProcessor(op_registration_utils.RegistrationProcessor):
'''Registration processor that excludes registrations and writes the result to an output file.'''
def __init__(self, required_ops: dict, op_type_usage_manager: typing.Optional[OperatorTypeUsageManager],
globally_allowed_types: typing.Optional[typing.Set[str]], output_file: str):
def __init__(self, required_ops: typing.Optional[dict],
op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface],
output_file: str):
self._required_ops = required_ops

if op_type_usage_manager is not None and globally_allowed_types is not None:
raise ValueError("At most one of op_type_usage_manager and globally_allowed_types may be provided.")

self._op_type_usage_manager = op_type_usage_manager

self._enable_all_ops = globally_allowed_types is not None and not required_ops
if self._enable_all_ops:
log.info("No required ops were specified but globally allowed types were specified. "
"Globally allowed types will be used to exclude op implementations.")

self._globally_allowed_types_re = \
_type_re_from_globally_allowed_types(globally_allowed_types) \
if globally_allowed_types is not None else None

self._op_type_impl_filter = op_type_impl_filter
self._output_file = output_file

def _is_op_required(self, domain: str, operator: str,
start_version: int, end_version: typing.Optional[int]) -> typing.Tuple[bool, str]:
'''See if an op should be excluded because it is not required.'''
if self._enable_all_ops:
'''See if an op is required.'''
if self._required_ops is None:
return True

if domain not in self._required_ops:
@@ -116,17 +65,10 @@ def process_registration(self, lines: typing.List[str], constant_for_domain: str
exclude = True
reason = "Entire op is not required."

if not exclude and type is not None:
if self._op_type_usage_manager is not None:
if not self._op_type_usage_manager.is_typed_registration_needed(domain, operator, type):
exclude = True
reason = "Specific typed registration is not required."

elif self._globally_allowed_types_re is not None:
if not self._globally_allowed_types_re.match(type):
exclude = True
reason = "Specific typed registration does not contain globally allowed types."

if not exclude and type is not None and self._op_type_impl_filter is not None:
if not self._op_type_impl_filter.is_typed_registration_needed(domain, operator, type):
exclude = True
reason = "Specific typed registration is not required."
else:
log.warning('Keeping {} registration from unknown domain: {}'
.format(registration_identifier, constant_for_domain))
@@ -152,9 +94,8 @@ def ok(self):

def _process_provider_registrations(
ort_root: str, use_cuda: bool,
required_ops: dict,
op_type_usage_manager: typing.Optional[OperatorTypeUsageManager],
globally_allowed_types: typing.Optional[typing.Set[str]]):
required_ops: typing.Optional[dict],
op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]):
'''Rewrite provider registration files.'''
kernel_registration_files = op_registration_utils.get_kernel_registration_files(ort_root, use_cuda)

@@ -169,8 +110,7 @@ def _process_provider_registrations(

# read from backup and overwrite original with commented out lines for any kernels that are not required
with open(kernel_registration_file, 'w') as file_to_write:
processor = _ExcludingRegistrationProcessor(
required_ops, op_type_usage_manager, globally_allowed_types, file_to_write)
processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, file_to_write)

op_registration_utils.process_kernel_registration_file(backup_path, processor)

@@ -231,20 +171,11 @@ def reduce_ops(config_path: str, enable_type_reduction: bool = False, use_cuda:
:param enable_type_reduction: Whether per operator type reduction is enabled
:param use_cuda: Whether to reduce op kernels for the CUDA provider
'''
required_ops, op_type_usage_manager, globally_allowed_types = parse_config(config_path, enable_type_reduction)

if globally_allowed_types is not None:
globally_allowed_types = _validated_globally_allowed_types(globally_allowed_types)
required_ops, op_type_impl_filter = parse_config(config_path, enable_type_reduction)

_process_provider_registrations(ort_root, use_cuda, required_ops, op_type_usage_manager, globally_allowed_types)
_process_provider_registrations(ort_root, use_cuda, required_ops, op_type_impl_filter)

if op_type_usage_manager is not None:
type_control_cpp_code = op_type_usage_manager.get_cpp_entries()
elif globally_allowed_types is not None:
type_control_cpp_code = ["ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES({});".format(
", ".join(sorted(globally_allowed_types)))]
else:
type_control_cpp_code = []
type_control_cpp_code = op_type_impl_filter.get_cpp_entries() if op_type_impl_filter is not None else []

_insert_type_control_cpp_code(ort_root, type_control_cpp_code)

5 changes: 4 additions & 1 deletion tools/python/util/ort_format_model/__init__.py
Original file line number Diff line number Diff line change
@@ -12,4 +12,7 @@

from .utils import create_config_from_models # noqa
from .ort_model_processor import OrtFormatModelProcessor # noqa
from .operator_type_usage_processors import OperatorTypeUsageManager # noqa
from .operator_type_usage_processors import ( # noqa
GloballyAllowedTypesOpTypeImplFilter,
OpTypeImplFilterInterface,
OperatorTypeUsageManager)
Loading

0 comments on commit b6c4a7a

Please sign in to comment.