From b654852b155d667a0c86adc8ff92d5eb7ca2c44b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 Jul 2024 12:54:05 -0500 Subject: [PATCH 001/202] [Bugfix] Allow import of TVM when current directory is read-only (#17142) * [Bugfix] Allow import of TVM when current directory is read-only Prior to this commit, TVM could only be imported if the current directory had write privileges. This was due to the use of `tvm.contrib.pickle_memoize` to cache the winograd transformation matrices. This commit makes multiple related fixes, to ensure that (1) TVM can be imported regardless of directory permissions, (2) the working directory is not left in a cluttered state, and (3) cache files are generated in an expected location to be reused later. * The cache directory is only generated when required, just prior to saving a cache. * The cache directory defaults to `$HOME/.cache/tvm/pkl_memoize`, rather than `.pkl_memorize_py3` in the working directory. * The cache directory respects `XDG_CACHE_HOME`, using `$XDG_CACHE_HOME/tvm/pkl_memoize` if set. * lint fix --- python/tvm/contrib/pickle_memoize.py | 58 +++++--- tests/python/contrib/pickle_memoize_script.py | 48 +++++++ tests/python/contrib/test_memoize.py | 126 ++++++++++++++++++ 3 files changed, 214 insertions(+), 18 deletions(-) create mode 100755 tests/python/contrib/pickle_memoize_script.py create mode 100644 tests/python/contrib/test_memoize.py diff --git a/python/tvm/contrib/pickle_memoize.py b/python/tvm/contrib/pickle_memoize.py index 6d2ffbac0673..4f3aff8fb5b0 100644 --- a/python/tvm/contrib/pickle_memoize.py +++ b/python/tvm/contrib/pickle_memoize.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. """Memoize result of function via pickle, used for cache testcases.""" + # pylint: disable=broad-except,superfluous-parens +import atexit import os +import pathlib import sys -import atexit + from decorator import decorate from .._ffi.base import string_types @@ -28,6 +31,17 @@ import pickle +def _get_global_cache_dir() -> pathlib.Path: + if "XDG_CACHE_HOME" in os.environ: + cache_home = pathlib.Path(os.environ.get("XDG_CACHE_HOME")) + else: + cache_home = pathlib.Path.home().joinpath(".cache") + return cache_home.joinpath("tvm", f"pkl_memoize_py{sys.version_info[0]}") + + +GLOBAL_CACHE_DIR = _get_global_cache_dir() + + class Cache(object): """A cache object for result cache. @@ -42,28 +56,36 @@ class Cache(object): cache_by_key = {} def __init__(self, key, save_at_exit): - cache_dir = f".pkl_memoize_py{sys.version_info[0]}" - try: - os.mkdir(cache_dir) - except FileExistsError: - pass - else: - self.cache = {} - self.path = os.path.join(cache_dir, key) - if os.path.exists(self.path): - try: - self.cache = pickle.load(open(self.path, "rb")) - except Exception: - self.cache = {} - else: - self.cache = {} + self._cache = None + + self.path = GLOBAL_CACHE_DIR.joinpath(key) self.dirty = False self.save_at_exit = save_at_exit + @property + def cache(self): + """Return the cache, initializing on first use.""" + + if self._cache is not None: + return self._cache + + if self.path.exists(): + with self.path.open("rb") as cache_file: + try: + cache = pickle.load(cache_file) + except pickle.UnpicklingError: + cache = {} + else: + cache = {} + + self._cache = cache + return self._cache + def save(self): if self.dirty: - print(f"Save memoize result to {self.path}") - with open(self.path, "wb") as out_file: + self.path.parent.mkdir(parents=True, exist_ok=True) + + with self.path.open("wb") as out_file: pickle.dump(self.cache, out_file, pickle.HIGHEST_PROTOCOL) diff --git a/tests/python/contrib/pickle_memoize_script.py b/tests/python/contrib/pickle_memoize_script.py new file mode 100755 index 000000000000..f0d73e391066 --- /dev/null +++ b/tests/python/contrib/pickle_memoize_script.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +import tvm + + +@tvm.contrib.pickle_memoize.memoize("test_memoize_save_data", save_at_exit=True) +def get_data_saved(): + return 42 + + +@tvm.contrib.pickle_memoize.memoize("test_memoize_transient_data", save_at_exit=False) +def get_data_transient(): + return 42 + + +def main(): + assert len(sys.argv) == 3, "Expect arguments SCRIPT NUM_SAVED NUM_TRANSIENT" + + num_iter_saved = int(sys.argv[1]) + num_iter_transient = int(sys.argv[2]) + + for _ in range(num_iter_saved): + get_data_saved() + for _ in range(num_iter_transient): + get_data_transient() + + +if __name__ == "__main__": + main() diff --git a/tests/python/contrib/test_memoize.py b/tests/python/contrib/test_memoize.py new file mode 100644 index 000000000000..6881940e5062 --- /dev/null +++ b/tests/python/contrib/test_memoize.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for tvm.contrib.pickle_memoize""" + +import os +import pathlib +import tempfile +import subprocess +import sys + +import tvm.testing + +TEST_SCRIPT_FILE = pathlib.Path(__file__).with_name("pickle_memoize_script.py").resolve() + + +def test_cache_dir_not_in_current_working_dir(): + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + subprocess.check_call([TEST_SCRIPT_FILE, "1", "1"], cwd=temp_dir) + + new_files = list(temp_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + +def test_current_directory_is_not_required_to_be_writable(): + """TVM may be imported without directory permissions + + This is a regression test. In previous implementations, the + `tvm.contrib.pickle_memoize.memoize` function would write to the + current directory when importing TVM. Import of a Python module + should not write to any directory. + + """ + + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + # User may read/cd into the temp dir, nobody may write to temp + # dir. + temp_dir.chmod(0o500) + subprocess.check_call([sys.executable, "-c", "import tvm"], cwd=temp_dir) + + +def test_cache_dir_defaults_to_home_config_cache(): + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + subprocess.check_call([TEST_SCRIPT_FILE, "1", "0"], cwd=temp_dir) + + new_files = list(temp_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + cache_dir = pathlib.Path.home().joinpath(".cache", "tvm", "pkl_memoize_py3") + assert cache_dir.exists() + cache_files = list(cache_dir.iterdir()) + assert len(cache_files) >= 1 + + +def test_cache_dir_respects_xdg_cache_home(): + with tempfile.TemporaryDirectory( + prefix="tvm_" + ) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as temp_cache_dir: + temp_cache_dir = pathlib.Path(temp_cache_dir) + temp_working_dir = pathlib.Path(temp_working_dir) + + subprocess.check_call( + [TEST_SCRIPT_FILE, "1", "0"], + cwd=temp_working_dir, + env={ + **os.environ, + "XDG_CACHE_HOME": temp_cache_dir.as_posix(), + }, + ) + + new_files = list(temp_working_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3") + assert cache_dir.exists() + cache_files = list(cache_dir.iterdir()) + assert len(cache_files) == 1 + + +def test_cache_dir_only_created_when_used(): + with tempfile.TemporaryDirectory( + prefix="tvm_" + ) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as temp_cache_dir: + temp_cache_dir = pathlib.Path(temp_cache_dir) + temp_working_dir = pathlib.Path(temp_working_dir) + + subprocess.check_call( + [TEST_SCRIPT_FILE, "0", "1"], + cwd=temp_working_dir, + env={ + **os.environ, + "XDG_CACHE_HOME": temp_cache_dir.as_posix(), + }, + ) + + cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3") + assert not cache_dir.exists() + + +if __name__ == "__main__": + tvm.testing.main() From 70c53082e6715516aefefcdca6262e195f36a0de Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 17 Jul 2024 02:34:19 +0800 Subject: [PATCH 002/202] [Relax] Fix fuseOps via pattern (#17160) fix fuseops via pattern --- src/relax/transform/fuse_ops.cc | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 45d70fc3e290..2be7ad41f3e1 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1222,7 +1222,12 @@ class CompositeFunctionAnnotator : public ExprMutator { IRModule Run() { auto mod = builder_->GetContextIRModule(); for (const auto& gv : mod->GetGlobalVars()) { - const auto& base_func = mod->Lookup(gv); + auto it = mod->functions.find(gv); + // Note that the fusion pass may have already removed the function. + if (it == mod->functions.end()) { + continue; + } + const auto& base_func = (*it).second; if (const auto* func = base_func.as()) { if (func->GetAttr(attr::kComposite).defined() || func->GetAttr(attr::kCodegen).defined()) { @@ -1399,7 +1404,7 @@ Pass FuseOps(int fuse_opt_level) { }; return CreateModulePass(/*pass_function=*/pass_func, // /*opt_level=*/0, // - /*pass_name=*/"FuseOps", // + /*name=*/"FuseOps", // /*required=*/{}); } @@ -1412,9 +1417,9 @@ Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_const return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen, entry_function_names); }; - return CreateModulePass(/*pass_function=*/pass_func, // - /*opt_level=*/0, // - /*pass_name=*/"FuseOpsByPattern", // + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*name=*/"FuseOpsByPattern", // /*required=*/{}); } From 51d7c5e47a108b7d03036e6a1045aa8348f9562c Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Subramaniam Date: Wed, 17 Jul 2024 01:52:00 +0530 Subject: [PATCH 003/202] [Hexagon] Support RPC execution of existing shared lib (#17162) This patch modifies the `get_executor_from_factory` for relax to support accepting a string that points to an already exported shared library. This allows us to run models that were already compiled through the RPC executor. --- python/tvm/contrib/hexagon/session.py | 33 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 9f1166823423..50064e42ba08 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -287,14 +287,14 @@ def get_graph_debug_executor( ) def get_executor_from_factory( - self, module: Union[ExecutorFactoryModule, relax.Executable], hexagon_arch: str = "v68" + self, module: Union[ExecutorFactoryModule, relax.Executable, str], hexagon_arch: str = "v68" ): """Create a local GraphModule which consumes a remote libmod. Parameters ---------- - module : Union[ExecutorFactoryModule, relax.Executable] + module : Union[ExecutorFactoryModule, relax.Executable, str] The module to upload to the remote session and load. @@ -305,7 +305,7 @@ def get_executor_from_factory( return self._aot_executor_from_factory(module) if isinstance(module, GraphExecutorFactoryModule): return self._graph_executor_from_factory(module) - if isinstance(module, relax.Executable): + if isinstance(module, (relax.Executable, str)): return self._relax_vm_executable_executor(module, hexagon_arch=hexagon_arch) raise TypeError(f"Unsupported executor type: {type(module)}") @@ -358,7 +358,9 @@ def _graph_executor_from_factory( """ return self.get_graph_executor(module.get_graph_json(), module.get_lib()) - def _relax_vm_executable_executor(self, vm_exec: relax.Executable, hexagon_arch: str): + def _relax_vm_executable_executor( + self, vm_exec: Union[relax.Executable, str], hexagon_arch: str + ): """Create a local TVM module which consumes a remote vm executable. Paramters @@ -366,7 +368,7 @@ def _relax_vm_executable_executor(self, vm_exec: relax.Executable, hexagon_arch: vm_exec : relax.Executable The Relax VM Executable to upload to the remote and load. This will typically be the - output of `relax.build`. + output of `relax.build` or the path to an already built and exported shared library hexagon_arch : str The hexagon arch to be used Returns @@ -376,14 +378,21 @@ def _relax_vm_executable_executor(self, vm_exec: relax.Executable, hexagon_arch: """ assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use" - temp_dir = utils.tempdir() - path_exec = temp_dir.relpath("exec.so") + if isinstance(vm_exec, relax.Executable): + temp_dir = utils.tempdir() + path_exec = temp_dir.relpath("exec.so") - vm_exec.mod.export_library( - path_exec, - fcompile=hexagon.create_aot_shared, - hexagon_arch=hexagon_arch, - ) + vm_exec.mod.export_library( + path_exec, + fcompile=hexagon.create_aot_shared, + hexagon_arch=hexagon_arch, + ) + + path = self.upload(path_exec, "exec.so") + elif isinstance(vm_exec, str): + path_exec = vm_exec + else: + raise TypeError(f"Unsupported executor type: {type(vm_exec)}") path = self.upload(path_exec, "exec.so") return self._rpc.get_function("tvm.hexagon.load_module")(str(path)) From 73078f11dcdc383246fefa50961a6a9bda6137cf Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 16 Jul 2024 16:34:24 -0500 Subject: [PATCH 004/202] [CI] Remove lint step from `unity/pr-head` step (#17155) * [CI] Remove lint step from `unity/pr-head` step This step should only be performed as part of the `lint/pr-head` CI step. It was included as part of the unity-specific CI steps prior to merging of unity into main. It is no longer necessary as part of `unity/pr-head`. * Revert the task_extra_lint.sh removal --- ci/jenkins/unity_jenkinsfile.groovy | 8 -------- 1 file changed, 8 deletions(-) mode change 100644 => 100755 ci/jenkins/unity_jenkinsfile.groovy diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy old mode 100644 new mode 100755 index b9047e8b6f64..9b4f0009e344 --- a/ci/jenkins/unity_jenkinsfile.groovy +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -210,14 +210,6 @@ def lint(node_type) { ) skip_ci = should_skip_ci(env.CHANGE_ID) skip_slow_tests = should_skip_slow_tests(env.CHANGE_ID) - sh( - script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh", - label: 'Run lint', - ) - sh( - script: "${docker_run} ${ci_lint} ./tests/scripts/unity/task_extra_lint.sh", - label: 'Run extra lint', - ) } } } From 22a89785bab2e120bb089a2d617342db0d157bc7 Mon Sep 17 00:00:00 2001 From: Cookiee235 Date: Thu, 18 Jul 2024 21:49:11 +0800 Subject: [PATCH 005/202] [Relax][BugFix] Fix a bug about the IR construction in test file (#17121) Update test_transform_dead_code_elimination.py Fix the wrong Relax IR construction --- tests/python/relax/test_transform_dead_code_elimination.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 0cb0d4624731..142faf51607b 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -454,7 +454,7 @@ def main( R.output(lv0) gv_x = R.astype(x, dtype="float16") - gv_w = R.astype(x, dtype="float16") + gv_w = R.astype(w, dtype="float16") with R.dataflow(): lv1: R.Tensor((2, 28, 28, 3), dtype="float16") = R.permute_dims( @@ -481,7 +481,7 @@ def main( w: R.Tensor((4, 3, 3, 3), dtype="float32"), ): gv_x = R.astype(x, dtype="float16") - gv_w = R.astype(x, dtype="float16") + gv_w = R.astype(w, dtype="float16") with R.dataflow(): lv1: R.Tensor((2, 28, 28, 3), dtype="float16") = R.permute_dims( From 70d86e3fb7adf2afc05797e749b62a1d9c6c788a Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 18 Jul 2024 22:49:45 +0900 Subject: [PATCH 006/202] [Meta Schedule][XGBoost] enable custom callback func test with xgboost>=1.6.0 (#17168) enable callback func test with xgboost>=1.6.0 --- .../test_meta_schedule_cost_model.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/tests/python/meta_schedule/test_meta_schedule_cost_model.py b/tests/python/meta_schedule/test_meta_schedule_cost_model.py index 0e1b2f64216b..dadedcf601aa 100644 --- a/tests/python/meta_schedule/test_meta_schedule_cost_model.py +++ b/tests/python/meta_schedule/test_meta_schedule_cost_model.py @@ -257,17 +257,6 @@ def test_meta_schedule_xgb_model_reupdate(): model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) -def xgb_version_check(): - - # pylint: disable=import-outside-toplevel - import xgboost as xgb - from packaging import version - - # pylint: enable=import-outside-toplevel - return version.parse(xgb.__version__) >= version.parse("1.6.0") - - -@unittest.skipIf(xgb_version_check(), "test not supported for xgboost version after 1.6.0") def test_meta_schedule_xgb_model_callback_as_function(): # pylint: disable=import-outside-toplevel from itertools import chain as itertools_chain @@ -330,14 +319,12 @@ def avg_peak_score(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignor num_boost_round=10000, obj=obj, callbacks=[ - partial( - _get_custom_call_back( - early_stopping_rounds=model.early_stopping_rounds, - verbose_eval=model.verbose_eval, - fevals=[rmse, avg_peak_score], - evals=[(d_train.dmatrix, "tr")], - cvfolds=None, - ) + _get_custom_call_back( + early_stopping_rounds=model.early_stopping_rounds, + verbose_eval=model.verbose_eval, + fevals=[rmse, avg_peak_score], + evals=[(d_train.dmatrix, "tr")], + cvfolds=None, ) ], ) From d006ecac35fd3100ee547d2d0356e21245a93ed0 Mon Sep 17 00:00:00 2001 From: tsu-bin <81693503+tsu-bin@users.noreply.github.com> Date: Thu, 18 Jul 2024 21:50:14 +0800 Subject: [PATCH 007/202] [Relax] [ONNX] Add support for Sign and Not (#17167) Co-authored-by: tsu-bin --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 18 ++++++++++++++++++ tests/python/relax/test_frontend_onnx.py | 8 ++++++++ 2 files changed, 26 insertions(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 3a70cd090a54..85d4402d6640 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1948,6 +1948,22 @@ def _impl_v14(cls, bb, inputs, attr, params): ) +class Sign(OnnxOpConverter): + """Converts an onnx Sign node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.sign(inputs[0]) + + +class Not(OnnxOpConverter): + """Converts an onnx Not node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.logical_not(inputs[0]) + + def _get_convert_map(): return { "MatMul": MatMul, @@ -2030,6 +2046,8 @@ def _get_convert_map(): "Elu": Elu, "HardSigmoid": HardSigmoid, "HardSwish": HardSwish, + "Sign": Sign, + "Not": Not, } diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 0fc7ec064402..05316f2699dd 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -600,6 +600,14 @@ def test_hardswish(): verify_unary("HardSwish", [32, 32]) +def test_sign(): + verify_unary("Sign", [32, 32]) + + +def test_not(): + verify_unary("Not", [32, 32], dtype=TensorProto.BOOL) + + def test_conv(): def _verify_conv(input_shape, weight_shape, output_shape): bias_shape = [output_shape[1]] From 070546eb4afddab5725dd145358931e9dfcb90f4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 18 Jul 2024 16:12:13 -0500 Subject: [PATCH 008/202] [TVMJS] Check DataType.NUMPY2STR when saving array (#17174) Prior to this commit, the `dtype` string used by `tvmjs.dump_ndarray_cache` was generated as `str(np_array.dtype)`. While this works in most cases, there are a few naming differences between TVM datatypes and numpy datatypes, such as `"float8_e4m3fn"` in Numpy being equivalent to `"e4m3_float8"` in TVM. This commit updates `dump_ndarray_cache` to check `DataType.NUMPY2STR` for the datatype string, allowing round-trip save/load of float8 arrays. --- python/tvm/contrib/tvmjs.py | 9 ++++- tests/python/contrib/test_tvmjs.py | 64 ++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 tests/python/contrib/test_tvmjs.py diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py index 2a7604c0ada2..9bff724df7bc 100644 --- a/python/tvm/contrib/tvmjs.py +++ b/python/tvm/contrib/tvmjs.py @@ -35,6 +35,7 @@ import tvm from tvm._ffi.libinfo import find_lib_path +from tvm.runtime import DataType from .emcc import create_tvmjs_wasm @@ -276,7 +277,13 @@ def dump_ndarray_cache( v = v.numpy() # prefer to preserve original dtype, especially if the format was bfloat16 - dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype) + dtype = origin_v.dtype if isinstance(origin_v, tvm.nd.NDArray) else v.dtype + + if dtype in DataType.NUMPY2STR: + dtype = DataType.NUMPY2STR[dtype] + else: + dtype = str(dtype) + total_bytes += math.prod(v.shape) * np.dtype(v.dtype).itemsize # convert fp32 to bf16 diff --git a/tests/python/contrib/test_tvmjs.py b/tests/python/contrib/test_tvmjs.py new file mode 100644 index 000000000000..22742ec224ef --- /dev/null +++ b/tests/python/contrib/test_tvmjs.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test contrib.tvmjs""" + +import tempfile + +import numpy as np +import pytest + +import tvm.testing +from tvm.contrib import tvmjs + +dtype = tvm.testing.parameter( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "float8_e4m3fn", + "float8_e5m2", +) + + +def test_save_load_float8(dtype): + if "float8" in dtype or "bfloat16" in dtype: + ml_dtypes = pytest.importorskip("ml_dtypes") + np_dtype = np.dtype(getattr(ml_dtypes, dtype)) + else: + np_dtype = np.dtype(dtype) + + arr = np.arange(16, dtype=np_dtype) + + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + tvmjs.dump_ndarray_cache({"arr": arr}, temp_dir) + cache, _ = tvmjs.load_ndarray_cache(temp_dir, tvm.cpu()) + + after_roundtrip = cache["arr"].numpy() + + np.testing.assert_array_equal(arr, after_roundtrip) + + +if __name__ == "__main__": + tvm.testing.main() From 3c7adfb1f7015078903ba53cc5317ead1b4f5f32 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 20 Jul 2024 04:00:01 +0900 Subject: [PATCH 009/202] Use `packaging.version.parse` instead of `distutils.version.LooseVersion` (#17173) use `packaging.version.parse` instead of `distutils.version.LooseVersion` --- python/tvm/contrib/msc/core/utils/info.py | 6 +++--- python/tvm/relay/frontend/pytorch_utils.py | 4 ++-- python/tvm/relay/op/contrib/ethosn.py | 6 +++--- python/tvm/relay/testing/tflite.py | 4 ++-- .../test_arm_compute_lib/test_network.py | 4 ++-- .../frontend/tensorflow/test_forward.py | 9 ++++----- tests/python/frontend/tflite/test_forward.py | 19 +++++++++---------- 7 files changed, 25 insertions(+), 27 deletions(-) diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index 4fea45f8fab2..58b08112797a 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -17,7 +17,7 @@ """tvm.contrib.msc.core.utils.info""" from typing import List, Tuple, Dict, Any, Union -from distutils.version import LooseVersion +from packaging.version import parse import numpy as np import tvm @@ -409,8 +409,8 @@ def get_version(framework: str) -> List[int]: raw_version = "1.0.0" except: # pylint: disable=bare-except raw_version = "1.0.0" - raw_version = raw_version or "1.0.0" - return LooseVersion(raw_version).version + version = parse(raw_version or "1.0.0") + return [version.major, version.minor, version.micro] def compare_version(given_version: List[int], target_version: List[int]) -> int: diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 7de1248bda77..8686be4b1ea9 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -36,7 +36,7 @@ def is_version_greater_than(ver): than the one given as an argument. """ import torch - from distutils.version import LooseVersion + from packaging.version import parse torch_ver = torch.__version__ # PT version numbers can include +cu[cuda version code] @@ -44,7 +44,7 @@ def is_version_greater_than(ver): if "+cu" in torch_ver: torch_ver = torch_ver.split("+cu")[0] - return LooseVersion(torch_ver) > ver + return parse(torch_ver) > parse(ver) def getattr_attr_name(node): diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py index 81534d48a216..c1e87ad5d90b 100644 --- a/python/tvm/relay/op/contrib/ethosn.py +++ b/python/tvm/relay/op/contrib/ethosn.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name, unused-argument """Arm(R) Ethos(TM)-N NPU supported operators.""" from enum import Enum -from distutils.version import LooseVersion +from packaging.version import parse import tvm.ir from tvm.relay import transform @@ -118,7 +118,7 @@ def partition_for_ethosn(mod, params=None, **opts): """ api_version = ethosn_api_version() supported_api_versions = ["3.2.0"] - if all(api_version != LooseVersion(exp_ver) for exp_ver in supported_api_versions): + if all(parse(api_version) != parse(exp_ver) for exp_ver in supported_api_versions): raise ValueError( f"Driver stack version {api_version} is unsupported. " f"Please use version in {supported_api_versions}." @@ -433,7 +433,7 @@ def split(expr): """Check if a split is supported by Ethos-N.""" if not ethosn_available(): return False - if ethosn_api_version() == LooseVersion("3.0.1"): + if parse(ethosn_api_version()) == parse("3.0.1"): return False if not _ethosn.split(expr): return False diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py index df9c0bcadf62..29f6bc62cad2 100644 --- a/python/tvm/relay/testing/tflite.py +++ b/python/tvm/relay/testing/tflite.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Common utilities for creating TFLite models""" -from distutils.version import LooseVersion +from packaging.version import parse import numpy as np import pytest import tflite.Model # pylint: disable=wrong-import-position @@ -134,7 +134,7 @@ def generate_reference_data(self): assert self.serial_model is not None, "TFLite model was not created." output_tolerance = None - if tf.__version__ < LooseVersion("2.5.0"): + if parse(tf.__version__) < parse("2.5.0"): output_tolerance = 1 interpreter = tf.lite.Interpreter(model_content=self.serial_model) else: diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py index 3cf81e971f77..8c6302abf842 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_network.py +++ b/tests/python/contrib/test_arm_compute_lib/test_network.py @@ -16,7 +16,7 @@ # under the License. """Arm Compute Library network tests.""" -from distutils.version import LooseVersion +from packaging.version import parse import numpy as np import pytest @@ -137,7 +137,7 @@ def get_model(): mod, params = _get_keras_model(mobilenet, inputs) return mod, params, inputs - if keras.__version__ < LooseVersion("2.9"): + if parse(keras.__version__) < parse("2.9"): # This can be removed after we migrate to TF/Keras >= 2.9 expected_tvm_ops = 56 expected_acl_partitions = 31 diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index db270ccb2e9f..354ed38a62ce 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -21,7 +21,6 @@ This article is a test script to test tensorflow operator with Relay. """ from __future__ import print_function -from distutils.version import LooseVersion import threading import platform @@ -1755,7 +1754,7 @@ def _test_concat_v2(shape1, shape2, dim): def test_forward_concat_v2(): - if tf.__version__ < LooseVersion("1.4.1"): + if package_version.parse(tf.__version__) < package_version.parse("1.4.1"): return _test_concat_v2([2, 3], [2, 3], 0) @@ -3128,7 +3127,7 @@ def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype) def test_forward_clip_by_value(): """test ClipByValue op""" - if tf.__version__ < LooseVersion("1.9"): + if package_version.parse(tf.__version__) < package_version.parse("1.9"): _test_forward_clip_by_value((4,), 0.1, 5.0, "float32") _test_forward_clip_by_value((4, 4), 1, 5, "int32") @@ -4482,7 +4481,7 @@ def _test_forward_zeros_like(in_shape, dtype): def test_forward_zeros_like(): - if tf.__version__ < LooseVersion("1.2"): + if package_version.parse(tf.__version__) < package_version.parse("1.2"): _test_forward_zeros_like((2, 3), "int32") _test_forward_zeros_like((2, 3, 5), "int8") _test_forward_zeros_like((2, 3, 5, 7), "uint16") @@ -5566,7 +5565,7 @@ def test_forward_spop(): # This test is expected to fail in TF version >= 2.6 # as the generated graph will be considered frozen, hence # not passing the criteria for the test below. - if tf.__version__ < LooseVersion("2.6.1"): + if package_version.parse(tf.__version__) < package_version.parse("2.6.1"): _test_spop_resource_variables() # Placeholder test cases diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 75a2a37c636a..cb0b17ea3fcf 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -22,7 +22,6 @@ """ from __future__ import print_function from functools import partial -from distutils.version import LooseVersion import platform import os import tempfile @@ -1054,7 +1053,7 @@ def representative_data_gen(): input_node = subgraph.Tensors(model_input).Name().decode("utf-8") tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): input_node = data_in.name.replace(":0", "") else: input_node = "serving_default_" + data_in.name + ":0" @@ -1775,7 +1774,7 @@ def representative_data_gen(): tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): input_node = data_in.name.replace(":0", "") else: input_node = "serving_default_" + data_in.name + ":0" @@ -2219,9 +2218,9 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8): tflite_output = run_tflite_graph(tflite_model_quant, data) # TFLite 2.6.x upgrade support - if tf.__version__ < LooseVersion("2.6.1"): + if package_version.parse(tf.__version__) < package_version.parse("2.6.1"): in_node = ["serving_default_input_int8"] - elif tf.__version__ < LooseVersion("2.9"): + elif package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = ( ["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"] ) @@ -2245,7 +2244,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8): """One iteration of rsqrt""" # tensorflow version upgrade support - if tf.__version__ < LooseVersion("2.6.1") or not quantized: + if package_version.parse(tf.__version__) < package_version.parse("2.6.1") or not quantized: return _test_unary_elemwise( math_ops.rsqrt, data, quantized, quant_range=[1, 6], int_quant_dtype=int_quant_dtype ) @@ -2254,7 +2253,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8): tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype ) tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = ["tfl.quantize"] else: in_node = "serving_default_input" @@ -2338,7 +2337,7 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8): tf.math.cos, data, int_quant_dtype=int_quant_dtype ) tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = ["tfl.quantize"] else: in_node = "serving_default_input" @@ -3396,7 +3395,7 @@ def representative_data_gen(): tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True) tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = data_in.name.split(":")[0] else: in_node = "serving_default_" + data_in.name + ":0" @@ -3426,7 +3425,7 @@ def representative_data_gen(): tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True) tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = data_in.name.split(":")[0] else: in_node = "serving_default_" + data_in.name + ":0" From e5bf56d1f4d4d46cfe4845e4f76c991be35cc332 Mon Sep 17 00:00:00 2001 From: arangasa <76030063+arangasa@users.noreply.github.com> Date: Mon, 22 Jul 2024 12:12:08 +0530 Subject: [PATCH 010/202] =?UTF-8?q?[Relay][FQ2I]:=20Use=20appropriate=20dt?= =?UTF-8?q?ype=20while=20quantizing=20relay.op.nn.pad=E2=80=A6=20(#17177)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Relay][FQ2I]: Use appropriate dtype while quantizing relay.op.nn.pad's constant pad value * Keep default axis --- .../transform/fake_quantization_to_integer.py | 2 +- .../test_pass_fake_quantization_to_integer.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index b27fc3cba799..7ad838895c9f 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -466,7 +466,7 @@ def pad(expr, type_map): # If the pad-value is a constant, we need to quantize it assert isinstance(pad_value, relay.expr.Constant) assert pad_value.checked_type.dtype in ["float32", "float64", "float16", "bfloat16"] - pad_value = relay.qnn.op.quantize(pad_value, t.scale, t.zero_point) + pad_value = relay.qnn.op.quantize(pad_value, t.scale, t.zero_point, out_dtype=t.dtype) out = relay.op.nn.pad(arg, pad_value=pad_value, **expr.attrs) return [out, t] diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 6edb3949d683..c0b61f72d1d3 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -814,6 +814,20 @@ def test_fake_quantize_pad(): compare_fq_to_int(op, [x_np]) +def test_fake_quantize_pad_with_float_min(): + in_shape = [1, 383, 128] + x = relay.var("x", shape=in_shape, dtype="float32") + op = relay.qnn.quantize(x, relay.const(1.0), relay.const(0), out_dtype="uint8") + op = relay.qnn.dequantize(op, relay.const(1.0), relay.const(0), out_dtype="float32") + op = relay.op.nn.pad( + op, pad_width=[[0, 0], [0, 1], [0, 0]], pad_value=relay.const(-3.40282e38, dtype="float32") + ) + op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(0), out_dtype="uint8") + x_np = np.random.randint(0, 256, size=in_shape) + x_as_float = x_np.astype("float32") + compare_fq_to_int(op, [x_as_float], True) + + def test_fake_quantize_depth_to_space(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") From 18ff9ff89b4617d8925ef6afde233e8d1742a5bd Mon Sep 17 00:00:00 2001 From: YXY-0922 <50567910+YXY-0922@users.noreply.github.com> Date: Tue, 23 Jul 2024 02:48:57 +0800 Subject: [PATCH 011/202] [MetaSchedule]Add a testcase for padded conv2d in meta_schedule (#17171) ### Bug Fix In the `TileWithTensorIntrin` function, when the `allow_padding` parameter is enabled, the original implementation inlines all consumer blocks. This behavior can lead to incorrect inlining of output blocks, causing issues with block shapes and dependencies. To ensure correct inlining operations, only non-output consumer blocks should be inlined. --------- Co-authored-by: yuxiyue --- src/tir/schedule/transform.cc | 4 +- ...test_meta_schedule_schedule_rule_mlt_tc.py | 152 ++++++++++++++++++ 2 files changed, 155 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 8f912c59ea16..fec214fa1fc7 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -340,7 +340,9 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block } auto consumers = sch->GetConsumers(block_rv); for (const auto& consumer : consumers) { - sch->ComputeInline(consumer); + auto sref = sch->GetSRef(consumer); + if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true))) + sch->ComputeInline(consumer); } } // Construct a mapping from tir loops back to LoopRVs diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py index df8607e55127..1fd2ab84749e 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py @@ -1055,5 +1055,157 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( ) +def test_padded_conv(): + # fmt: off + @T.prim_func + def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator") + PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", scope="shared") + weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16", scope="shared") + PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544, 160), "float16", scope="wmma.matrix_a") + weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_0_0 in range(10): + for ax0_ax1_fused in range(28672): + with T.block("PadInput_reindex_pad_shared"): + v0 = T.axis.spatial(12544, ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16) + v1 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused % 16) + T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3]) + T.writes(PadInput_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + PadInput_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 + v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0)) + for ax0_ax1_fused in range(512): + with T.block("weight_reindex_pad_shared"): + v0 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 2 * 32 + ax0_ax1_fused % 32) + T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1]) + T.writes(weight_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + weight_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], T.float16(0)) + for ax2_0_1 in range(1): + for ax0_0, ax1_0 in T.grid(14, 1): + with T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0) + v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0) + T.reads(PadInput_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_pad_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("weight_reindex_pad_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0) + v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0) + T.reads(weight_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("weight_reindex_pad_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(7, 2, 1, 2, 1): + with T.block("conv2d_nhwc_o"): + v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4) + v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 + ax2_0_2) + T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(14): + for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused) + v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2) + v2_o = T.axis.spatial(14, ax2 + ax2_1) + v3_o = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(4096): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_ax3_ax4_ax5_fused // 512) + v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2) + v2 = T.axis.spatial(14, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [7, 1, 8, 7, 2]), + ("SamplePerfectTile", [2, 1, 1, 2, 1]), + ("SamplePerfectTile", [10, 1, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ] + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + 1, + 224, + 224, + 3, + 64, + 7, + 2, + 3, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_70"), + types=None, + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[padded_conv2d_0], + expected_decisions=[decision_0], + ) + + if __name__ == "__main__": tvm.testing.main() From 5d5edd2fd8b891bb74681f83095d606739cadfcb Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 22 Jul 2024 12:36:06 -0700 Subject: [PATCH 012/202] [Relax] Integrate cuDNN attention (#17157) * [Relax] Integrate cuDNN attention * update cmake * lint * lint * cudnn frontend * lint * lint * fix test * skip test --- cmake/config.cmake | 7 + cmake/modules/CUDA.cmake | 16 ++ python/tvm/contrib/cutlass/build.py | 32 +-- python/tvm/contrib/cutlass/gen_tensor_op.py | 4 +- python/tvm/relax/backend/contrib/cudnn.py | 99 ++++++- python/tvm/relax/backend/contrib/cutlass.py | 18 +- python/tvm/relax/backend/patterns.py | 32 ++- python/tvm/relax/frontend/nn/op.py | 9 +- python/tvm/relax/testing/__init__.py | 1 + python/tvm/relax/testing/attention.py | 148 ++++++++++ python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/attention_python.py | 122 ++++++++ src/relax/backend/contrib/cudnn/codegen.cc | 47 +++ src/relax/transform/allocate_workspace.cc | 9 +- src/relax/transform/fuse_ops.cc | 19 +- .../contrib/cudnn/cudnn_frontend/attention.cc | 124 ++++++++ .../contrib/cudnn/cudnn_frontend/attention.h | 83 ++++++ .../contrib/cudnn/cudnn_json_runtime.cc | 267 +++++++++++------- tests/python/relax/test_codegen_cudnn.py | 65 ++++- tests/python/relax/test_codegen_cutlass.py | 213 ++++---------- .../test_transform_allocate_workspace.py | 3 +- ...est_transform_merge_composite_functions.py | 5 +- 22 files changed, 1010 insertions(+), 314 deletions(-) create mode 100644 python/tvm/relax/testing/attention.py create mode 100644 python/tvm/topi/testing/attention_python.py create mode 100644 src/runtime/contrib/cudnn/cudnn_frontend/attention.cc create mode 100644 src/runtime/contrib/cudnn/cudnn_frontend/attention.h diff --git a/cmake/config.cmake b/cmake/config.cmake index 416eec0dcb81..26d50630f7d3 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -245,6 +245,13 @@ set(USE_EDGETPU OFF) # - /path/to/cudnn: use specific path to cuDNN path set(USE_CUDNN OFF) +# Whether use cuDNN frontend +# Possible values: +# - ON: enable cuDNN frontend +# - /path/to/cudnn_frontend: use specific path to cuDNN frontend +# - OFF: disable cuDNN frontend +set(USE_CUDNN_FRONTEND OFF) + # Whether use cuBLAS set(USE_CUBLAS OFF) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index b7b405f82286..ad83ebe26b8c 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -77,6 +77,22 @@ if(USE_CUDA) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY}) endif(USE_CUDNN) + if (USE_CUDNN_FRONTEND) + message(STATUS "Build with cuDNN Frontend support") + if (IS_DIRECTORY ${USE_CUDNN_FRONTEND}) + find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h HINTS ${USE_CUDNN_FRONTEND}/include) + include_directories(SYSTEM ${USE_CUDNN_FRONTEND}/include) + else() + find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h) + endif() + if (NOT CUDNN_FRONTEND_HEADER) + message(FATAL_ERROR "Cannot find cudnn_frontend.h, please set USE_CUDNN_FRONTEND to the path of the cuDNN frontend header") + endif() + tvm_file_glob(GLOB CONTRIB_CUDNN_FRONTEND_SRCS src/runtime/contrib/cudnn/cudnn_frontend/*.cc) + set_property(SOURCE ${CONTRIB_CUDNN_SRCS} APPEND PROPERTY COMPILE_DEFINITIONS TVM_USE_CUDNN_FRONTEND=1) + list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_FRONTEND_SRCS}) + endif(USE_CUDNN_FRONTEND) + if(USE_CUBLAS) message(STATUS "Build with cuBLAS support") tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 1c0a30c62d91..5c09c79bd906 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -868,34 +868,26 @@ def handle_attention(self, f, op_type): signature = _extract_relax_function_signature(f) if _get_call_node(f.body, "relax.nn.attention") is not None: - op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs + attention_node = _get_call_node(f.body, "relax.nn.attention") + op_attrs = attention_node.attrs elif _get_call_node(f.body, "relax.nn.attention_bias") is not None: - op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs + attention_node = _get_call_node(f.body, "relax.nn.attention_bias") + op_attrs = attention_node.attrs elif _get_call_node(f.body, "relax.nn.attention_var_len") is not None: - op_attrs = _get_call_node(f.body, "relax.nn.attention_var_len").attrs + attention_node = _get_call_node(f.body, "relax.nn.attention_var_len") + op_attrs = attention_node.attrs else: raise ValueError("Cannot find call node for attention") arg = {} if "stacked_attention" in op_type: - arg["arg0_shape"] = signature["arg0_shape"] arg["arg0_dtype"] = signature["arg0_dtype"] - arg["arg1_shape"] = q_shape = signature["arg1_shape"] - - if "arg3_shape" not in signature: - # arg0: qkv, arg1: shape, arg2: workspace - arg["arg2_shape"] = k_shape = signature["arg1_shape"] - arg["arg3_shape"] = v_shape = signature["arg1_shape"] - else: - # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: workspace - arg["arg2_shape"] = k_shape = signature["arg2_shape"] - arg["arg3_shape"] = v_shape = signature["arg3_shape"] - - if "arg5_dtype" in signature: - # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: bias, arg5: workspace - arg["bias_dtype"] = signature["arg4_dtype"] - if "arg5_shape" in signature: - arg["bias_shape"] = signature["arg4_shape"] + q_shape = get_const_tuple(attention_node.args[0].struct_info.shape) + k_shape = get_const_tuple(attention_node.args[1].struct_info.shape) + v_shape = get_const_tuple(attention_node.args[2].struct_info.shape) + if len(attention_node.args) == 4: + arg["bias_shape"] = get_const_tuple(attention_node.args[3].struct_info.shape) + arg["bias_dtype"] = attention_node.args[3].struct_info.dtype qkv_layout = "qkv_stacked" else: diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 2f21a1d313e2..5d04cf13e693 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -745,8 +745,8 @@ def get_batch_on_arg(arg_name, arg_shape): attrs["qkv"] = func_args[0] attrs["num_queries"] = s = annotations["num_queries"] attrs["num_keys"] = annotations["num_keys"] - if len(func_args) > 5 and not is_var_len: # +1 for workspace, the last arg - attrs["bias"] = func_args[4] + if len(func_args) > 2 and not is_var_len: # +1 for workspace, the last arg + attrs["bias"] = func_args[1] else: raise NotImplementedError() diff --git a/python/tvm/relax/backend/contrib/cudnn.py b/python/tvm/relax/backend/contrib/cudnn.py index f730d4e5be0a..2f15e3a4fd19 100644 --- a/python/tvm/relax/backend/contrib/cudnn.py +++ b/python/tvm/relax/backend/contrib/cudnn.py @@ -16,11 +16,16 @@ # under the License. """Pattern table for cuDNN backend""" -from tvm.relax import transform +import operator +from functools import partial, reduce + +import tvm +from tvm import relax +from tvm.relax import PyExprMutator, expr_functor, transform from tvm.relax.transform import PatternCheckContext from ..pattern_registry import get_patterns_with_prefix, register_patterns -from ..patterns import make_conv2d_pattern +from ..patterns import make_conv2d_pattern, make_stacked_attention_pattern from ..utils import has_leaking_intermediate_variables @@ -60,6 +65,29 @@ def _check_conv2d(context: PatternCheckContext) -> bool: return True +def _check_stacked_attention(context: PatternCheckContext, layout: str) -> bool: + """Check if the given stacked attention workload can be offloaded to cuDNN.""" + if has_leaking_intermediate_variables(context): + return False + if layout == "BS3NH": + if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3: + return False + if "split" in context.annotated_expr: + split_op = context.annotated_expr["split"] + if not split_op.attrs.axis == 2: + return False + elif layout == "SBN3H": + if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 4: + return False + if "split" in context.annotated_expr: + split_op = context.annotated_expr["split"] + if not split_op.attrs.axis == 3: + return False + else: + raise NotImplementedError(f"Unsupported layout: {layout}") + return True + + register_patterns( [ ( @@ -84,6 +112,16 @@ def _check_conv2d(context: PatternCheckContext) -> bool: ), _check_conv2d, ), + ( + "cudnn.attention.BS3NH", + *make_stacked_attention_pattern(start_op="split", layout="BS3NH"), + partial(_check_stacked_attention, layout="BS3NH"), + ), + ( + "cudnn.attention.SBN3H", + *make_stacked_attention_pattern(start_op="split", layout="SBN3H"), + partial(_check_stacked_attention, layout="SBN3H"), + ), ] ) @@ -105,4 +143,59 @@ def partition_for_cudnn(mod): """ patterns = get_patterns_with_prefix("cudnn") - return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) + return tvm.transform.Sequential( + [ + transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True), + annotate_workspace, + transform.AllocateWorkspace(), + ] + )(mod) + + +def _shape_1d(shape): + return reduce(operator.mul, shape, 1) + + +@expr_functor.mutator +class WorkspaceAnnotator(PyExprMutator): + """Annotate a workspace requirement for each cuDNN-offloaded function.""" + + def __init__(self, mod): + super().__init__(mod) + + def visit_function_(self, f): + if "Composite" not in f.attrs: + body = super().visit_expr(f.body) + new_f = relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span) + + if "global_symbol" in f.attrs and "cudnn" in f.attrs["global_symbol"]: + composite_func = body.blocks[0].bindings[0].value + if "WorkspaceSize" in composite_func.attrs: + return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"]) + + return new_f + + if "attention" in f.attrs["Composite"] and "cudnn" in f.attrs["Composite"]: + # Workspace is needed only for larger head sizes, but for simplicity we always allocate. + out_dtype = f.ret_struct_info.dtype + out_size_1d = _shape_1d(f.ret_struct_info.shape) + # This needs to be in sync with the actual value that the kernel expects. + workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype] + if not isinstance(workspace_size_bytes, (int, tvm.tir.expr.IntImm)): + # Tempororay workaround for dynamic shape workload. Will be removed when + # workspace for dynamic shape workload is implemented. + workspace_size_bytes = 8 + return f.with_attr("WorkspaceSize", workspace_size_bytes) + + return f + + +@tvm.transform.module_pass(opt_level=0) +def annotate_workspace(mod, _): + """Pass to annotate a workspace requirement for each cuDNN-offloaded function.""" + annotator = WorkspaceAnnotator(mod) + for name, f in mod.functions_items(): + if isinstance(f, relax.Function): + new_f = annotator.visit_expr(f) + mod.update_func(name, new_f) + return mod diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 0d9f4ff8e923..80979bbe7e25 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -383,19 +383,25 @@ def _check_stacked_attention(context: PatternCheckContext) -> bool: if not split_op.attrs.axis == 2: return False else: + get_const_int_list = lambda tup: [int(e.value) for e in tup] last_end = 0 for name in ["query", "key", "value"]: assert f"strided_slice_{name}" in context.annotated_expr strided_slice_op = context.annotated_expr[f"strided_slice_{name}"] - if list(strided_slice_op.attrs.axes) != [2]: + axes = get_const_int_list(strided_slice_op.args[1]) + begins = get_const_int_list(strided_slice_op.args[2]) + ends = get_const_int_list(strided_slice_op.args[3]) + strides = get_const_int_list(strided_slice_op.args[4]) + + if axes != [2]: return False - if list(strided_slice_op.attrs.begin) != [last_end]: + if begins != [last_end]: return False - if not len(strided_slice_op.attrs.end) == 1: + if not len(ends) == 1: return False - last_end = strided_slice_op.attrs.end[0] - if list(strided_slice_op.attrs.strides) != [1]: + if strides != [1]: return False + last_end = ends[0] return True @@ -537,7 +543,7 @@ def visit_function_(self, f): return new_f - if "attention" in f.attrs["Composite"]: + if "attention" in f.attrs["Composite"] and "cutlass" in f.attrs["Composite"]: # Workspace is needed only for larger head sizes, but for simplicity we always allocate. out_dtype = f.ret_struct_info.dtype out_size_1d = _shape_1d(f.ret_struct_info.shape) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 8ec43f1f27f6..1faef9cceb05 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -260,7 +260,7 @@ def make_attention_pattern(with_bias: bool = False, var_len: bool = False): return out, annotations -def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): +def make_stacked_attention_pattern(start_op: str, with_bias: bool = False, layout="BS3NH"): """ Create pattern for fused multi head attention with stacked input. @@ -272,6 +272,9 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): with_bias: bool Whether or not to include bias addition + layout: str + The layout of the stacked input tensor. + Returns ------- pattern: DFPattern @@ -290,17 +293,28 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): key_raw = is_tuple_get_item(qkv_tuple, 1) value_raw = is_tuple_get_item(qkv_tuple, 2) elif start_op == "strided_slice": - ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(stacked_qkv) - ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(stacked_qkv) - ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(stacked_qkv) + ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")( + stacked_qkv, varg_default_wildcard=True + ) + ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")( + stacked_qkv, varg_default_wildcard=True + ) + ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")( + stacked_qkv, varg_default_wildcard=True + ) else: raise NotImplementedError() query_reshape_list = wildcard() key_reshape_list = wildcard() value_reshape_list = wildcard() - query = is_op("relax.reshape")(query_raw, query_reshape_list) - key = is_op("relax.reshape")(key_raw, key_reshape_list) - value = is_op("relax.reshape")(value_raw, value_reshape_list) + if layout == "BS3NH": + query = is_op("relax.reshape")(query_raw, query_reshape_list) + key = is_op("relax.reshape")(key_raw, key_reshape_list) + value = is_op("relax.reshape")(value_raw, value_reshape_list) + elif layout == "SBN3H": + ops["q_transpose"] = query = is_op("relax.permute_dims")(query_raw) + ops["k_transpose"] = key = is_op("relax.permute_dims")(key_raw) + ops["v_transpose"] = value = is_op("relax.permute_dims")(value_raw) annotations = { "stacked_qkv": stacked_qkv, "query_reshape_list": query_reshape_list, @@ -314,6 +328,10 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): out = is_op("relax.nn.attention_bias")(query, key, value, bias) else: out = is_op("relax.nn.attention")(query, key, value) + + if layout == "SBN3H": + out = is_op("relax.permute_dims")(out) + return out, annotations diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 725a930fd680..ec072f663cd5 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1568,11 +1568,14 @@ def scaled_dot_product_attention( Parameters ---------- query : Tensor - Tensor representing current attention lookup. + Tensor representing current attention lookup of shape + [batch, seq_len, num_heads, head_size]. key : Tensor - Tensor representing cross attention mapping. + Tensor representing cross attention mapping of shape + [batch, seq_len_kv, num_heads_kv, head_size]. value : Tensor - Tensor representing embedded attention values. + Tensor representing embedded attention values of shape + [batch, seq_len_kv, num_heads_kv, head_size_value]. attn_mask : Optional[Tensor] Optional mask for attention, not yet supported. is_causal : Optional[bool] diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py index 4256ebc3be89..dc43d6c1f8ee 100644 --- a/python/tvm/relax/testing/__init__.py +++ b/python/tvm/relax/testing/__init__.py @@ -21,3 +21,4 @@ from .relay_translator import * from .ast_printer import dump_ast from .matmul import * +from .attention import * diff --git a/python/tvm/relax/testing/attention.py b/python/tvm/relax/testing/attention.py new file mode 100644 index 000000000000..a00674394ba2 --- /dev/null +++ b/python/tvm/relax/testing/attention.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Relax script for attention module.""" +import tvm +from tvm.script import relax as R, tir as T +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder + + +def get_relax_attention_module( + q_shape, + k_shape, + v_shape, + *, + dtype, + bias_shape=None, + qk_scale=None, + causal_mask=None, + window_size=None, +): # pylint: disable=too-many-arguments, too-many-locals, invalid-name + """Get a relax module for attention.""" + + if qk_scale is not None: + qk_scale = T.FloatImm("float32", qk_scale) + + if window_size is not None: + window_size = T.IntImm("int32", window_size) + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + q = R.arg("q", R.Tensor(q_shape, dtype)) + k = R.arg("k", R.Tensor(k_shape, dtype)) + v = R.arg("v", R.Tensor(v_shape, dtype)) + bias = None + if bias_shape is not None and bias_shape != "none": + bias = R.arg("bias", R.Tensor(bias_shape, dtype)) + + with R.dataflow() as frame: + result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask, window_size)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_relax_stacked_attention_module( + qkv, + b, + s, + n, + h, + h_v, + op, + bias=None, + qk_scale=None, + single_shape=False, + layout="BS3NH", +): # pylint: disable=too-many-arguments, too-many-locals, too-many-branches, invalid-name + # pylint: disable=too-many-statements + """Get a relax module for stacked attention.""" + dtype = str(qkv.dtype) + assert layout in ["BS3NH", "SBN3H"] + + if qk_scale is not None: + qk_scale = T.FloatImm("float32", qk_scale) + + if single_shape: + if layout == "BS3NH": + qk_shape = R.shape([b, s, n, h]) + elif layout == "SBN3H": + qk_shape = R.shape([b, s, n, h]) + v_shape = qk_shape + else: + if layout == "BS3NH": + qk_shape = [b, s, n, h] + v_shape = [b, s, n, h_v] + elif layout == "SBN3H": + qk_shape = [s, b, n, h] + v_shape = [s, b, n, h_v] + + if layout == "BS3NH": + split_axis = 2 + split_sections = [n * h, n * h * 2] + elif layout == "SBN3H": + split_axis = 3 + split_sections = [h, h * 2] + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype)) + if bias is not None: + bias = R.arg("bias", R.Tensor(bias.shape, dtype)) + with R.dataflow() as frame: + if op == "split": + qkv_tuple = R.split(qkv, split_sections, axis=split_axis) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + elif op == "strided_slice": + q = R.strided_slice(qkv, [split_axis], [0], [split_sections[0]], [1]) + k = R.strided_slice( + qkv, [split_axis], [split_sections[0]], [split_sections[1]], [1] + ) + v = R.strided_slice( + qkv, + [split_axis], + [split_sections[1]], + [int(qkv.struct_info.shape[split_axis])], + [1], + ) + else: + raise NotImplementedError() + if layout == "BS3NH": + q = R.reshape(q, qk_shape) + k = R.reshape(k, qk_shape) + v = R.reshape(v, v_shape) + elif layout == "SBN3H": + q = R.permute_dims(q, [1, 0, 2, 3]) + k = R.permute_dims(k, [1, 0, 2, 3]) + v = R.permute_dims(v, [1, 0, 2, 3]) + result = R.emit(R.nn.attention(q, k, v, bias, qk_scale)) + if layout == "SBN3H": + result = R.emit(R.permute_dims(result, [1, 0, 2, 3])) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 72a7cedc491c..1486e9986e0e 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -84,3 +84,4 @@ from .searchsorted import searchsorted_ref from .conv2d_backcward_weight_python import conv2d_backward_weight_python from .lstm_python import lstm_python +from .attention_python import attention_python diff --git a/python/tvm/topi/testing/attention_python.py b/python/tvm/topi/testing/attention_python.py new file mode 100644 index 000000000000..856667aeddd1 --- /dev/null +++ b/python/tvm/topi/testing/attention_python.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Attention operator in python""" +from typing import Optional +import numpy as np +from .softmax_python import softmax_python + + +def attention_python( + q: np.ndarray, + k: np.ndarray, + v: np.ndarray, + bias: Optional[np.ndarray], + qk_scale: float, + causal: str, + window_size: Optional[int] = None, + layout: str = "BSNH", +): # pylint: disable=too-many-arguments, too-many-locals, invalid-name + """Attention operator in python + + Parameters + ---------- + q : np.ndarray + Query tensor with shape [batch, seq_length, num_heads, head_dim] in the layout specified by + `layout`. + k : np.ndarray + Key tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim] in the layout specified + by `layout`. + v : np.ndarray + Value tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim_v] in the layout + specified by `layout`. + bias : np.ndarray + Bias tensor with shape [batch, num_heads, seq_length, seq_length] + qk_scale : float + Scale factor for the query-key product. + causal : str + The type of causal mask to apply. Can be "none", "TopLeft", or "BottomRight". + window_size : Optional[int] + The window size for the causal mask. + layout : str + The layout of the input tensors, e.g. "BSNH" or "BNSH". + + Returns + ------- + np.ndarray + The output tensor with shape [batch, seq_length, num_heads, head_dim_v] in the layout + specified by `layout`. + """ + assert layout in ["BSNH", "BNSH", "SBNH"] + + dim_b = layout.find("B") + dim_s = layout.find("S") + dim_n = layout.find("N") + dim_h = layout.find("H") + + q = q.transpose(dim_b, dim_n, dim_s, dim_h) # b, n, s, h + k = k.transpose(dim_b, dim_n, dim_s, dim_h) # b, n, s_kv, h + kt = k.transpose(0, 1, 3, 2) # b, n, h, s_kv + v = v.transpose(dim_b, dim_n, dim_s, dim_h) + + num_heads = q.shape[1] + num_kv_heads = k.shape[1] + s = q.shape[2] + s_kv = k.shape[2] + + if num_heads != num_kv_heads: + assert num_heads % num_kv_heads == 0 + factor = num_heads // num_kv_heads + kt = np.repeat(kt, factor, axis=1) + v = np.repeat(v, factor, axis=1) + + if not qk_scale == "none": + score = q @ kt * qk_scale # b, n, s, s_kv + else: + score = q @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv + if bias is not None: + score = score + bias # b, n, s, s_kv + if causal == "none": + attn = softmax_python(score, -1) + else: + if causal == "TopLeft": + offset = 0 + elif causal == "BottomRight": + offset = abs(s - s_kv) + else: + raise ValueError(f"Unsupported causal type: {causal}") + score_masked = np.tril(score, k=offset) + + if window_size: + score_masked = np.triu( + score_masked, -window_size + 1 # pylint: disable=invalid-unary-operand-type + ) + + score_masked_exp = np.tril( + np.exp(score_masked - np.max(score_masked, axis=-1, keepdims=True)), k=offset + ) + + if window_size: + score_masked_exp = np.triu( + score_masked_exp, -window_size + 1 # pylint: disable=invalid-unary-operand-type + ) + + score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True) + attn = np.divide(score_masked_exp, score_masked_sum) + + out = attn @ v # b, n, s, h_v + return out.transpose(*np.argsort([dim_b, dim_n, dim_s, dim_h]).tolist()) diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index 812016b8eafa..d8ca5f4e97f4 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -55,6 +55,17 @@ class cuDNNJSONSerializer : public JSONSerializer { std::string composite_name = composite_opt.value(); + if (composite_name.find("cudnn.conv2d") != std::string::npos) { + return HandleConv2D(call_node, fn, composite_name); + } else if (composite_name.find("cudnn.attention") != std::string::npos) { + return HandleAttention(call_node, fn, composite_name); + } else { + LOG(FATAL) << "Unsupported composite function: " << composite_name; + } + } + + NodeEntries HandleConv2D(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { NodeEntries inputs_tmp; for (const auto& arg : call_node->args) { auto res = VisitExpr(arg); @@ -80,6 +91,42 @@ class cuDNNJSONSerializer : public JSONSerializer { return AddNode(node, GetRef(call_node)); } + NodeEntries HandleAttention(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + std::string layout = composite_name.substr(composite_name.find_last_of(".") + 1); + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + ICHECK_EQ(inputs.size(), 2); + auto node = std::make_shared(composite_name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + const CallNode* root_call = backend::GetOpInFunction(fn, "relax.nn.attention"); + auto q_shape = Downcast( + Downcast(root_call->args[0]->struct_info_.value())->shape.value()); + auto k_shape = Downcast( + Downcast(root_call->args[1]->struct_info_.value())->shape.value()); + auto v_shape = Downcast( + Downcast(root_call->args[2]->struct_info_.value())->shape.value()); + int num_heads = q_shape->values[2].as()->value; + int num_kv_heads = k_shape->values[2].as()->value; + int head_size = q_shape->values[3].as()->value; + int head_size_v = v_shape->values[3].as()->value; + SetCallNodeAttribute(node, root_call); + + auto to_str_array = [](int val) { + return std::vector{std::vector{std::to_string(val)}}; + }; + node->SetAttr("num_heads", to_str_array(num_heads)); + node->SetAttr("num_kv_heads", to_str_array(num_kv_heads)); + node->SetAttr("head_size", to_str_array(head_size)); + node->SetAttr("head_size_v", to_str_array(head_size_v)); + node->SetAttr("layout", std::vector{std::vector{layout}}); + return AddNode(node, GetRef(call_node)); + } + private: /*! \brief The bindings to look up composite functions. */ Map bindings_; diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 1d4a0177126a..05aa8ce5528d 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -66,8 +66,10 @@ class ExternFunctionRewriter : ExprMutator { } new_params.push_back(workspace_param); + auto new_attrs = func_node->attrs; + new_attrs.CopyOnWrite()->dict.erase(attr::kWorkspaceSize); return Function(new_params, VisitExpr(func_node->body), func_node->ret_struct_info, - func_node->is_pure, func_node->attrs); + func_node->is_pure, new_attrs); } return ExprMutator::VisitExpr_(func_node); } @@ -122,6 +124,7 @@ class WorkspaceProvider : ExprMutator { builder_->UpdateFunction(new_gvar, WithAttr(f, tvm::attr::kGlobalSymbol, new_gvar->name_hint)); gvar_map_[gvar] = new_gvar; + new_gvars_.insert(new_gvar); builder_->GetContextIRModule()->Remove(GetRef(gvar)); } @@ -164,8 +167,7 @@ class WorkspaceProvider : ExprMutator { auto new_op = VisitExpr(call_node->op); if (auto gv = new_op.as()) { - auto callee = builder_->GetContextIRModule()->Lookup(gv.value()); - if (callee->HasNonzeroAttr(attr::kWorkspaceSize)) { + if (new_gvars_.count(gv.value())) { auto new_args = call_node->args; ICHECK(workspace_var_main_.defined()); new_args.push_back(workspace_var_main_); @@ -185,6 +187,7 @@ class WorkspaceProvider : ExprMutator { * the new ones that are transformed to take an additional workspace parameter. This is only * needed since the struct info of the global variables changes between transformation. */ std::unordered_map gvar_map_; + std::unordered_set new_gvars_; }; } // namespace relax diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 2be7ad41f3e1..6030a28d93b6 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -595,8 +596,7 @@ class FunctionCreator : public ExprMutator { } StructInfo param_sinfo = GetStructInfo(expr); - // Exclude PrimValues from arg/params to make composite functions contain PrimValues. - if (!expr->IsInstance()) { + if (!IsInlinableConstants(expr)) { Var param(std::move(name), GetStructInfo(expr)); arguments_.push_back(expr); params_.push_back(param); @@ -621,6 +621,21 @@ class FunctionCreator : public ExprMutator { return ExprMutator::VisitExpr(expr); } + // Check if the expression is constant PrimValue or ShapeExpr or tuple of them that can be + // inlined in the composite functions and excluded from args/params. + bool IsInlinableConstants(const Expr& expr) { + if (const auto* tuple = expr.as()) { + return std::all_of(tuple->fields.begin(), tuple->fields.end(), + [this](const Expr& e) { return IsInlinableConstants(e); }); + } else if (const auto* prim_value = expr.as()) { + return tvm::tir::UndefinedVars(prim_value->value).empty(); + } else if (const auto* shape_expr = expr.as()) { + return std::all_of(shape_expr->values.begin(), shape_expr->values.end(), + [this](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); }); + } + return false; + } + private: /*! \brief The variables defined in this function */ std::unordered_set defined_vars_; diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc new file mode 100644 index 000000000000..f8b170fe2052 --- /dev/null +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/cudnn/cudnn_frontend/attention.cc + * \brief cuDNN scale dot product attention implementation + */ + +#include "./attention.h" + +#include +#include + +#include "../../../cuda/cuda_common.h" +#include "../cudnn_utils.h" + +namespace tvm { +namespace contrib { + +void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, int64_t head_size_v, + double scale, const DLDataType& data_type, + const std::string& layout) { + graph_ = std::make_unique(); + + CHECK(data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16) + << "Only float16 is supported"; + + graph_->set_io_data_type(cudnn_frontend::DataType_t::HALF) + .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + + auto q_desc = cudnn_frontend::graph::Tensor_attributes().set_name("Q").set_uid(kTensorIDQ); + auto k_desc = cudnn_frontend::graph::Tensor_attributes().set_name("K").set_uid(kTensorIDK); + auto v_desc = cudnn_frontend::graph::Tensor_attributes().set_name("V").set_uid(kTensorIDV); + auto o_desc = cudnn_frontend::graph::Tensor_attributes().set_name("Out").set_uid(kTensorIDOut); + + std::vector q_stride, k_stride, v_stride, + o_stride; // stride in the order of (batch, num_heads, seq_len, head_size) + + if (layout == "BS3NH") { + int64_t stride_H = 1; + int64_t q_stride_N = head_size; + int64_t k_stride_N = head_size; + int64_t v_stride_N = head_size_v; + int64_t stride_S = + num_heads * q_stride_N + num_kv_heads * k_stride_N + num_kv_heads * v_stride_N; + int64_t stride_B = stride_S * seq_len; + q_stride = {stride_B, q_stride_N, stride_S, stride_H}; + k_stride = {stride_B, k_stride_N, stride_S, stride_H}; + v_stride = {stride_B, v_stride_N, stride_S, stride_H}; + o_stride = {seq_len * num_heads * head_size_v, head_size_v, num_heads * head_size_v, 1}; + offset_k_ = num_heads * head_size; + offset_v_ = offset_k_ + num_kv_heads * head_size; + } else if (layout == "SBN3H") { + CHECK_EQ(num_kv_heads, num_heads); + int64_t stride_H = 1; + int64_t stride_N = head_size + head_size + head_size_v; + int64_t stride_B = num_heads * stride_N; + int64_t stride_S = stride_B * batch; + q_stride = k_stride = v_stride = {stride_B, stride_N, stride_S, stride_H}; + o_stride = {num_heads * head_size_v, head_size_v, num_heads * head_size_v * batch, 1}; + offset_k_ = head_size; + offset_v_ = offset_k_ * 2; + } else { + LOG(FATAL) << "Unsupported layout: " << layout; + } + + q_desc = q_desc.set_dim({batch, num_heads, seq_len, head_size}).set_stride(q_stride); + k_desc = k_desc.set_dim({batch, num_kv_heads, seq_len, head_size}).set_stride(k_stride); + v_desc = v_desc.set_dim({batch, num_kv_heads, seq_len, head_size_v}).set_stride(v_stride); + auto sdpa_options = cudnn_frontend::graph::SDPA_attributes() + .set_name("flash_attention") + .set_is_inference(true) + .set_alibi_mask(false) + .set_causal_mask(false) + .set_attn_scale(scale); + + auto q = graph_->tensor(q_desc); + auto k = graph_->tensor(k_desc); + auto v = graph_->tensor(v_desc); + auto [o, stats] = graph_->sdpa(q, k, v, sdpa_options); + CHECK(stats == nullptr); + o->set_output(true).set_dim({batch, num_heads, seq_len, head_size_v}).set_stride(o_stride); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CUDNN_FRONTEND_CALL(graph_->build(entry_ptr->handle, {cudnn_frontend::HeurMode_t::A})); +} + +void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out) { + CUDNN_CALL( + cudnnSetStream(CuDNNThreadEntry::ThreadLocal()->handle, tvm::runtime::GetCUDAStream())); + auto* qkv_base = reinterpret_cast(qkv->data) + qkv->byte_offset; + auto* q_ptr = reinterpret_cast(qkv_base) + offset_q_; + auto* k_ptr = reinterpret_cast(qkv_base) + offset_k_; + auto* v_ptr = reinterpret_cast(qkv_base) + offset_v_; + auto* out_ptr = reinterpret_cast(out->data) + out->byte_offset; + + size_t workspace_size = graph_->get_workspace_size(); + CHECK_LE(workspace_size, workspace->shape[0]) << "Workspace size too small"; + std::unordered_map inputs = { + {kTensorIDQ, q_ptr}, {kTensorIDK, k_ptr}, {kTensorIDV, v_ptr}, {kTensorIDOut, out_ptr}}; + + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CUDNN_FRONTEND_CALL(graph_->execute(entry_ptr->handle, inputs, workspace->data)); +} + +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h new file mode 100644 index 000000000000..4d0309fb3ba6 --- /dev/null +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/cudnn/cudnn_frontend/attention.h + * \brief cuDNN scale dot product attention implementation + */ + +#ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_ +#define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_ + +#include +#include + +#include +#include + +#define CUDNN_FRONTEND_CALL(func) \ + do { \ + auto status = (func); \ + CHECK(status.is_good()) << status.get_message(); \ + } while (0) + +namespace tvm { +namespace contrib { + +class CuDNNSDPARunnerNode : public tvm::runtime::Object { + public: + CuDNNSDPARunnerNode() {} + + ~CuDNNSDPARunnerNode() {} + + static constexpr const char* _type_key = "contrib.cudnn.SDPARunner"; + + void Init(int64_t batch, int64_t seq_len, int64_t num_heads, int64_t num_kv_heads, + int64_t head_size, int64_t head_size_v, double scale, const DLDataType& data_type, + const std::string& layout); + + void Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out); + + static constexpr int kTensorIDQ = 0; + static constexpr int kTensorIDK = 1; + static constexpr int kTensorIDV = 2; + static constexpr int kTensorIDOut = 4; + + private: + std::unique_ptr graph_{nullptr}; + int64_t offset_q_{0}; + int64_t offset_k_{0}; + int64_t offset_v_{0}; +}; + +class CuDNNSDPARunner : public tvm::runtime::ObjectRef { + public: + static CuDNNSDPARunner Create() { + auto n = make_object(); + return CuDNNSDPARunner(n); + } + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CuDNNSDPARunner, tvm::runtime::ObjectRef, + CuDNNSDPARunnerNode); +}; + +} // namespace contrib +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_ diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index 7d701396d0ca..3f4b659275d4 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -31,6 +31,10 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" + +#ifdef TVM_USE_CUDNN_FRONTEND +#include "./cudnn_frontend/attention.h" +#endif #include "cudnn_utils.h" namespace tvm { @@ -47,78 +51,19 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { : JSONRuntimeBase(symbol_name, graph_json, const_names) {} void Init(const Array& consts) override { - auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal(); - auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); - ICHECK(func != nullptr); - stream = static_cast((*func)().operator void*()); - - auto attr_in_name = [](const std::string& op_name, const std::string& attr_name) { - return op_name.find(attr_name) != std::string::npos; - }; - - auto vstr2vint = [](const JSONGraphNode& node, const std::string& attrStr) { - auto string_to_int = [](const std::string& str) { return std::stoi(str); }; - auto string_vec = node.GetAttr>(attrStr); - std::vector int_vec(string_vec.size()); - std::transform(string_vec.begin(), string_vec.end(), int_vec.begin(), string_to_int); - return int_vec; - }; + op_execs_.resize(nodes_.size()); // get some config from the graph for (size_t i = 0; i < nodes_.size(); ++i) { const auto& node = nodes_[i]; if (node.GetOpType() == "kernel") { - op_name = node.GetOpName(); - std::vector input_dims, kernel_dims, output_dims; - auto input_node = nodes_[0]; - auto input_shapes = input_node.GetOpShape()[0]; - auto kernel_node = nodes_[1]; - auto kernel_shapes = kernel_node.GetOpShape()[0]; - auto output_shapes = node.GetOpShape()[0]; - for (const auto& _i : input_shapes) { - input_dims.emplace_back(static_cast(_i)); - } - for (const auto& _i : kernel_shapes) { - kernel_dims.emplace_back(static_cast(_i)); + std::string op_name = node.GetOpName(); + if (op_name.find("conv2d") != std::string::npos) { + op_execs_[i] = GetConv2DExec(node); + } else if (op_name.find("attention") != std::string::npos) { + op_execs_[i] = GetAttentionExec(node); + } else { + LOG(FATAL) << "Unsupported op: " << op_name; } - for (const auto& _i : output_shapes) { - output_dims.emplace_back(static_cast(_i)); - } - has_bias = attr_in_name(op_name, "bias"); - groups = std::stoi(node.GetAttr>("groups")[0]); - padding = vstr2vint(node, "padding"); - strides = vstr2vint(node, "strides"); - dilation = vstr2vint(node, "dilation"); - conv_dtype = node.GetAttr>("out_dtype")[0]; - std::string layout = node.GetAttr>("out_layout")[0]; - dims = layout.size() - 2; // remove O and I dims - - if (layout == "NCHW") - format = CUDNN_TENSOR_NCHW; - else if (layout == "NHWC") - format = CUDNN_TENSOR_NHWC; - else - LOG(FATAL) << "Unsupported layout: " << layout; - - if (attr_in_name(op_name, "relu")) { - act = CUDNN_ACTIVATION_RELU; - } else if (attr_in_name(op_name, "relu6")) { - act = CUDNN_ACTIVATION_CLIPPED_RELU; - coef = 6.0; - } else if (attr_in_name(op_name, "leaky_relu")) { - act = CUDNN_ACTIVATION_RELU; - coef = 0.1; - } - this->handle = entry_ptr->handle; - this->kernel_node = node; - - // find best algo - TVMRetValue best_algo; - - tvm::contrib::FindAlgo(format, dims, groups, padding.data(), strides.data(), - dilation.data(), input_dims.data(), kernel_dims.data(), - output_dims.data(), conv_dtype, conv_dtype, false, &best_algo); - - this->algo = best_algo.operator int(); } } } @@ -126,27 +71,10 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { const char* type_key() const override { return "cudnn_json"; } // May be overridden void Run() override { - auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { - const DLTensor* bias = nullptr; - if (has_bias) { - bias = GetInput(node, 2); + for (const auto& f : op_execs_) { + if (f != nullptr) { + f(); } - return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias); - }; - - auto [a_ptr, b_ptr, bias_ptr] = get_inputs(kernel_node, has_bias); - uint32_t output_eid = EntryID(outputs_[0]); - auto out_ptr = data_entry_[output_eid]; - - if (this->has_bias) { - tvm::contrib::ConvolutionBiasActivationForward( - this->mode, this->format, this->algo, this->dims, this->groups, this->act, this->coef, - this->padding.data(), this->strides.data(), this->dilation.data(), a_ptr, b_ptr, out_ptr, - bias_ptr, this->conv_dtype); - } else { - tvm::contrib::ConvolutionForward( - this->mode, this->format, this->algo, this->dims, this->groups, this->padding.data(), - this->strides.data(), this->dilation.data(), a_ptr, b_ptr, out_ptr, this->conv_dtype); } } @@ -157,27 +85,150 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { ICHECK(eid < data_entry_.size()); return data_entry_[eid]; } - /*conv op name*/ - std::string op_name; - /*conv mode: CUDNN_CROSS_CORRELATION by default*/ - int mode = CUDNN_CROSS_CORRELATION; - /*algo: by default we select the implicit gemm algo, will be tuned in the initial pass.*/ - int algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; - /*if has bias*/ - bool has_bias = false; - /*args for function call*/ - int act = CUDNN_ACTIVATION_IDENTITY; - double coef = 1.0; - int format = CUDNN_TENSOR_NHWC; - int dims = 2; - int groups = 1; - std::vector padding; - std::vector strides; - std::vector dilation; - std::string conv_dtype; - cudaStream_t stream; - cudnnHandle_t handle; - tvm::runtime::json::JSONGraphNode kernel_node; + + bool attr_in_name(const std::string& op_name, const std::string& attr_name) { + return op_name.find(attr_name) != std::string::npos; + } + + std::vector vstr2vint(const JSONGraphNode& node, const std::string& attrStr) { + auto string_to_int = [](const std::string& str) { return std::stoi(str); }; + auto string_vec = node.GetAttr>(attrStr); + std::vector int_vec(string_vec.size()); + std::transform(string_vec.begin(), string_vec.end(), int_vec.begin(), string_to_int); + return int_vec; + } + + std::function GetConv2DExec(const JSONGraphNode& node) { + auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal(); + auto op_name = node.GetOpName(); + + std::vector input_dims, kernel_dims, output_dims; + auto input_node = nodes_[0]; + auto input_shapes = input_node.GetOpShape()[0]; + auto kernel_shapes = nodes_[1].GetOpShape()[0]; + auto output_shapes = node.GetOpShape()[0]; + for (const auto& _i : input_shapes) { + input_dims.emplace_back(static_cast(_i)); + } + for (const auto& _i : kernel_shapes) { + kernel_dims.emplace_back(static_cast(_i)); + } + for (const auto& _i : output_shapes) { + output_dims.emplace_back(static_cast(_i)); + } + bool has_bias = attr_in_name(op_name, "bias"); + int groups = std::stoi(node.GetAttr>("groups")[0]); + std::vector padding = vstr2vint(node, "padding"); + std::vector strides = vstr2vint(node, "strides"); + std::vector dilation = vstr2vint(node, "dilation"); + auto conv_dtype = node.GetAttr>("out_dtype")[0]; + std::string layout = node.GetAttr>("out_layout")[0]; + int dims = layout.size() - 2; // remove O and I dims + + int format = CUDNN_TENSOR_NHWC; + if (layout == "NCHW") { + format = CUDNN_TENSOR_NCHW; + } else if (layout == "NHWC") { + format = CUDNN_TENSOR_NHWC; + } else { + LOG(FATAL) << "Unsupported layout: " << layout; + } + + int act = CUDNN_ACTIVATION_IDENTITY; + double coef = 1.0; + if (attr_in_name(op_name, "relu")) { + act = CUDNN_ACTIVATION_RELU; + } else if (attr_in_name(op_name, "relu6")) { + act = CUDNN_ACTIVATION_CLIPPED_RELU; + coef = 6.0; + } else if (attr_in_name(op_name, "leaky_relu")) { + act = CUDNN_ACTIVATION_RELU; + coef = 0.1; + } + + /*conv mode: CUDNN_CROSS_CORRELATION by default*/ + int mode = CUDNN_CROSS_CORRELATION; + + // find best algo + TVMRetValue best_algo; + + tvm::contrib::FindAlgo(format, dims, groups, padding.data(), strides.data(), dilation.data(), + input_dims.data(), kernel_dims.data(), output_dims.data(), conv_dtype, + conv_dtype, false, &best_algo); + + int algo = best_algo.operator int(); + std::function op_exec = [=]() { + auto stream = static_cast(GetCUDAStream()); + CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream)); + + auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { + const DLTensor* bias = nullptr; + if (has_bias) { + bias = GetInput(node, 2); + } + return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias); + }; + + auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, has_bias); + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = data_entry_[output_eid]; + if (has_bias) { + tvm::contrib::ConvolutionBiasActivationForward( + mode, format, algo, dims, groups, act, coef, padding.data(), strides.data(), + dilation.data(), a_ptr, b_ptr, out_ptr, bias_ptr, conv_dtype); + } else { + tvm::contrib::ConvolutionForward(mode, format, algo, dims, groups, padding.data(), + strides.data(), dilation.data(), a_ptr, b_ptr, out_ptr, + conv_dtype); + } + }; + return op_exec; + } + + std::function GetAttentionExec(const JSONGraphNode& node) { +#ifdef TVM_USE_CUDNN_FRONTEND + auto dtype = node.GetOpDataType()[0]; + int num_heads = vstr2vint(node, "num_heads")[0]; + int num_kv_heads = vstr2vint(node, "num_kv_heads")[0]; + int head_size = vstr2vint(node, "head_size")[0]; + int head_size_v = vstr2vint(node, "head_size_v")[0]; + std::string layout = node.GetAttr>("layout")[0]; + const auto& input_qkv_node = nodes_[EntryID(node.GetInputs()[0])]; + auto qkv_shapes = input_qkv_node.GetOpShape()[0]; + + int64_t batch, seq_len; + if (layout == "BS3NH") { + ICHECK_EQ(qkv_shapes.size(), 3); + batch = qkv_shapes[0]; + seq_len = qkv_shapes[1]; + } else if (layout == "SBN3H") { + ICHECK_EQ(qkv_shapes.size(), 4); + batch = qkv_shapes[1]; + seq_len = qkv_shapes[0]; + } else { + LOG(FATAL) << "Unsupported layout: " << layout; + } + double scale = 1 / std::sqrt(head_size); + std::string scale_attr = node.GetAttr>("scale")[0]; + if (scale_attr.size()) { + scale = std::stod(scale_attr); + } + + auto runner = tvm::contrib::CuDNNSDPARunner::Create(); + runner->Init(batch, seq_len, num_heads, num_kv_heads, head_size, head_size_v, scale, dtype, + layout); + return [=]() { + auto qkv = GetInput(node, 0); + auto workspace = const_cast(GetInput(node, 1)); + auto out = const_cast(data_entry_[EntryID(outputs_[0])]); + runner->Run(qkv, workspace, out); + }; +#else + LOG(FATAL) << "Please build with CUDNN frontend to use attention op"; +#endif + } + + std::vector> op_execs_; }; runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py index 0f911905f820..59f49bfde889 100644 --- a/tests/python/relax/test_codegen_cudnn.py +++ b/tests/python/relax/test_codegen_cudnn.py @@ -22,7 +22,8 @@ import tvm.topi.testing from tvm import relax from tvm.relax.backend.contrib.cudnn import partition_for_cudnn -from tvm.relax.testing import get_relax_matmul_module +from tvm.relax.testing import get_relax_matmul_module, get_relax_stacked_attention_module +from tvm.contrib.pickle_memoize import memoize from tvm.script import relax as R from tvm.script.ir_builder import IRBuilder @@ -99,7 +100,7 @@ def get_relax_conv2d_module( def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False): mod = partition_for_cudnn(mod) mod = relax.transform.RunCodegen()(mod) - return build_and_run(mod, np_inputs, "cuda", cuda_graph) + return build_and_run(mod, np_inputs, "cuda", cuda_graph=cuda_graph) def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False): @@ -244,5 +245,65 @@ def test_conv2d_nchw_oihw_offload(data_shape, weight_shape, dtype, with_bias, ac tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +@memoize("topi.tests.test_codegen_cudnn.test_stacked_attention_offload") +def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, qk_scale, dtype, layout): + if layout == "BS3NH": + qkv = np.random.randn(b, s, n * h * 2 + n * h_v).astype(dtype) + split_qkv = np.split(qkv, [n * h, n * h * 2], axis=2) + q = split_qkv[0].reshape(b, s, n, h) + k = split_qkv[1].reshape(b, s, n, h) + v = split_qkv[2].reshape(b, s, n, h_v) + layout = "BSNH" + elif layout == "SBN3H": + qkv = np.random.randn(s, b, n, h * 2 + h_v).astype(dtype) + q, k, v = np.split(qkv, [h, h * 2], axis=3) + layout = "SBNH" + else: + raise ValueError("Unsupported layout: {}".format(layout)) + if not bias_shape == "none": + bias = np.random.randn(*bias_shape).astype(dtype) + score = score + bias # b, n, s, s + else: + bias = None + ref = tvm.topi.testing.attention_python(q, k, v, bias, qk_scale, "none", None, layout) + return qkv, bias, ref + + +@pytest.fixture( + params=[ + # B, S, N, H, bias_shape scale, single_shape, layout + (4, 8, 32, (64, 32), "none", 1.0, False, "BS3NH"), + (4, 8, 32, (64, 64), "none", "none", True, "BS3NH"), + (4, 8, 32, (64, 32), "none", 1.0, False, "SBN3H"), + (4, 8, 32, (64, 64), "none", "none", True, "SBN3H"), + ] +) +def stacked_attention_size(request): + return request.param + + +@pytest.mark.skip(reason="require cudnn frontend") +def test_stacked_attention_split_offload(stacked_attention_size): + b, s, n, (h, h_v), bias_shape, scale, single_shape, layout = stacked_attention_size + qkv, bias, ref = get_numpy_stacked_attention_ref( + b, s, n, h, h_v, bias_shape, scale, "float16", layout + ) + if scale == "none": + mod = get_relax_stacked_attention_module( + qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape, layout=layout + ) + scale = 1.0 / np.sqrt(h) + else: + mod = get_relax_stacked_attention_module( + qkv, b, s, n, h, h_v, "split", bias, scale, single_shape=single_shape, layout=layout + ) + + if bias is None: + out = get_result_with_relax_cudnn_offload(mod, [qkv]) + else: + out = get_result_with_relax_cudnn_offload(mod, [qkv, bias]) + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=2e-2) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 969651f72fd4..3fa3f2d914d7 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -24,7 +24,11 @@ from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul from tvm.contrib.pickle_memoize import memoize from tvm.relax.backend.contrib.cutlass import partition_for_cutlass -from tvm.relax.testing import get_relax_matmul_module +from tvm.relax.testing import ( + get_relax_matmul_module, + get_relax_attention_module, + get_relax_stacked_attention_module, +) from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T @@ -594,47 +598,6 @@ def attention_size(request): return request.param -def get_relax_attention_module( - q_shape, - k_shape, - v_shape, - *, - dtype, - bias_shape=None, - qk_scale=None, - causal_mask=None, - window_size=None, -): - from tvm.script.ir_builder import IRBuilder - from tvm.script.ir_builder import relax as relax_builder - from tvm.script.ir_builder import tir as T - - if qk_scale is not None: - qk_scale = T.FloatImm("float32", qk_scale) - - if window_size is not None: - window_size = T.IntImm("int32", window_size) - - with IRBuilder() as builder: - with relax_builder.function(): - R.func_name("main") - q = R.arg("q", R.Tensor(q_shape, dtype)) - k = R.arg("k", R.Tensor(k_shape, dtype)) - v = R.arg("v", R.Tensor(v_shape, dtype)) - bias = None - if bias_shape is not None and bias_shape != "none": - bias = R.arg("bias", R.Tensor(bias_shape, dtype)) - - with R.dataflow() as frame: - result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask, window_size)) - R.output(result) - - R.func_ret_value(frame.output_vars[0]) - - func = builder.get() - return tvm.IRModule({"main": func}) - - def get_numpy_attention_ref( b, s, @@ -649,59 +612,20 @@ def get_numpy_attention_ref( window_size=None, num_kv_head=None, ): - if num_kv_head is None: - num_kv_head = n - + num_kv_head = num_kv_head or n q = np.random.randn(b, s, n, h).astype(dtype) - k_orig = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype) - v_orig = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype) - - if num_kv_head is None: - k = k_orig - v = v_orig - else: - factor = n // num_kv_head - k = np.repeat(k_orig, factor, axis=2) - v = np.repeat(v_orig, factor, axis=2) - - qt = q.transpose(0, 2, 1, 3) # b, n, s, h - kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv - if not qk_scale == "none": - score = qt @ kt * qk_scale # b, n, s, s_kv - else: - score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv - if not bias_shape == "none": - bias = np.random.randn(*bias_shape).astype(dtype) - score = score + bias # b, n, s, s_kv - else: + k = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype) + v = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype) + if bias_shape == "none": bias = None - if causal == "none": - attn = tvm.topi.testing.softmax_python(score, -1) else: - if causal == "TopLeft": - offset = 0 - elif causal == "BottomRight": - offset = abs(s - s_kv) - else: - raise NotImplementedError() - score_masked = np.tril(score, k=offset) - - if window_size: - score_masked = np.triu(score_masked, -window_size + 1) - - score_masked_exp = np.tril( - np.exp(score_masked - np.max(score_masked, axis=-1, keepdims=True)), k=offset - ) - - if window_size: - score_masked_exp = np.triu(score_masked_exp, -window_size + 1) + bias = np.random.randn(*bias_shape).astype(dtype) - score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True) - attn = np.divide(score_masked_exp, score_masked_sum) + ref = tvm.topi.testing.attention_python( + q, k, v, bias, qk_scale, causal=causal, window_size=window_size, layout="BSNH" + ) - vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v - ref = attn @ vt # b, n, s, h_v - return q, k_orig, v_orig, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v + return q, k, v, bias, ref def test_attention_offload(attention_size, attention_dtype): @@ -844,69 +768,14 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, qk_scale, dtype q = np.reshape(split_qkv[0], (b, s, n, h)) k = np.reshape(split_qkv[1], (b, s, n, h)) v = np.reshape(split_qkv[2], (b, s, n, h_v)) - qt = q.transpose(0, 2, 1, 3) # b, n, s, h - kt = k.transpose(0, 2, 3, 1) # b, n, h, s - if not qk_scale == "none": - score = qt @ kt * qk_scale # b, n, s, s - else: - score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s if not bias_shape == "none": bias = np.random.randn(*bias_shape).astype(dtype) - score = score + bias # b, n, s, s else: bias = None - attn = tvm.topi.testing.softmax_python(score, -1) - vt = v.transpose(0, 2, 1, 3) # b, n, s, h_v - ref = attn @ vt # b, n, s, h_v - return qkv, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v - - -def get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None, single_shape=False -): - dtype = str(qkv.dtype) - - from tvm.script.ir_builder import IRBuilder - from tvm.script.ir_builder import relax as relax_builder - from tvm.script.ir_builder import tir as T - - if qk_scale is not None: - qk_scale = T.FloatImm("float32", qk_scale) - - if single_shape: - qk_shape = R.shape([b, s, n, h]) - v_shape = qk_shape - else: - qk_shape = [b, s, n, h] - v_shape = [b, s, n, h_v] - - with IRBuilder() as builder: - with relax_builder.function(): - R.func_name("main") - qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype)) - if bias is not None: - bias = R.arg("bias", R.Tensor(bias.shape, dtype)) - with R.dataflow() as frame: - if op == "split": - qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2) - q = R.reshape(qkv_tuple[0], qk_shape) - k = R.reshape(qkv_tuple[1], qk_shape) - v = R.reshape(qkv_tuple[2], v_shape) - elif op == "strided_slice": - q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), qk_shape) - k = R.reshape(R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), qk_shape) - v = R.reshape( - R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]), v_shape - ) - else: - raise NotImplementedError() - result = R.emit(R.nn.attention(q, k, v, bias, qk_scale)) - R.output(result) - - R.func_ret_value(frame.output_vars[0]) - - func = builder.get() - return tvm.IRModule({"main": func}) + ref = tvm.topi.testing.attention_python( + q, k, v, bias, qk_scale, causal="none", window_size=None, layout="BSNH" + ) + return qkv, bias, ref @pytest.fixture( @@ -926,11 +795,30 @@ def test_stacked_attention_split_offload(stacked_attention_size): qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float16") if scale == "none": mod = get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape + qkv, + b, + s, + n, + h, + h_v, + "split", + bias, + single_shape=single_shape, + layout="BS3NH", ) else: mod = get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, "split", bias, scale, single_shape=single_shape + qkv, + b, + s, + n, + h, + h_v, + "split", + bias, + scale, + single_shape=single_shape, + layout="BS3NH", ) if bias is None: @@ -945,11 +833,30 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size): qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float32") if scale == "none": mod = get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, "strided_slice", bias, single_shape=single_shape + qkv, + b, + s, + n, + h, + h_v, + "strided_slice", + bias, + single_shape=single_shape, + layout="BS3NH", ) else: mod = get_relax_stacked_attention_module( - qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape + qkv, + b, + s, + n, + h, + h_v, + "strided_slice", + bias, + scale, + single_shape=single_shape, + layout="BS3NH", ) if bias is None: out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2) diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py index 1198642d3f35..248d195d654b 100644 --- a/tests/python/relax/test_transform_allocate_workspace.py +++ b/tests/python/relax/test_transform_allocate_workspace.py @@ -95,7 +95,6 @@ def fused_relax_nn_attention_cutlass1( R.func_attr( { "Codegen": "cutlass", - "WorkspaceSize": 65536, "global_symbol": "fused_relax_nn_attention_cutlass1", } ) @@ -107,7 +106,7 @@ def gv( v_1: R.Tensor((32, 8, 16, 8), dtype="float16"), workspace_1: R.Tensor((65536,), dtype="uint8"), ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): - R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, "WorkspaceSize": 65536}) + R.func_attr({"Composite": "cutlass.attention", "Primitive": 1}) with R.dataflow(): gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") = R.nn.attention( q_1, k_1, v_1, scale=None diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index 6a36314a7444..cff832a21ff9 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -1053,7 +1053,6 @@ class Expected: @R.function def fused_relax_reshape_relax_matmul_tensorrt( inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"), - param_0: R.Shape([1, 784]), lv1: R.Tensor((784, 512), dtype="float32"), ) -> R.Tensor((1, 512), dtype="float32"): R.func_attr({"Codegen": "tensorrt"}) @@ -1069,7 +1068,7 @@ def lv_1( R.output(gv) return gv - lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, param_0) + lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, R.shape([1, 784])) @R.function def lv1_1_1( @@ -1100,7 +1099,7 @@ def main( ) gv: R.Tensor( (1, 512), dtype="float32" - ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, R.shape([1, 784]), lv1) + ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, lv1) R.output(gv) return gv From 929b8f49ac73db3c6c7430bc1a414d4210e1aae5 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 23 Jul 2024 06:28:04 +0900 Subject: [PATCH 013/202] [Relax][PyTorch] Add support for torch.permute (#17184) * add testcase * support torch.permute --- python/tvm/relax/frontend/torch/fx_translator.py | 4 ++++ tests/python/relax/test_frontend_from_fx.py | 9 +++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5ed0f18deb9e..f9a5d9c33f02 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -550,7 +550,11 @@ def _flatten(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.reshape(x, new_shape)) def _permute(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) def _reshape(self, node: fx.node.Node) -> relax.Var: diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index dd2719f8ce91..46c079aa99cc 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3029,10 +3029,14 @@ def forward(self, x): def test_permute(): input_info = [([1, 2, 3, 4], "float32")] - class Permute(Module): + class Permute1(Module): def forward(self, x): return x.permute(0, 3, 2, 1) + class Permute2(Module): + def forward(self, x): + return torch.permute(x, (0, 3, 2, 1)) + @tvm.script.ir_module class expected1: @R.function @@ -3046,7 +3050,8 @@ def main( R.output(gv) return gv - verify_model(Permute(), input_info, {}, expected1) + verify_model(Permute1(), input_info, {}, expected1) + verify_model(Permute2(), input_info, {}, expected1) def test_reshape(): From 91e9c63b42fcccec196a8ef9ed7a7bc7f82c2e52 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 22 Jul 2024 16:12:53 -0700 Subject: [PATCH 014/202] [FFI] Add python signal handler for ctypes FFI (#17181) --- python/tvm/_ffi/_ctypes/packed_func.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 6465e0335db0..5f3aa04914be 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -195,6 +195,7 @@ class PackedFuncBase(object): """Function base.""" __slots__ = ["handle", "is_global"] + # pylint: disable=no-member def __init__(self, handle, is_global): """Initialize the function with handle @@ -342,6 +343,7 @@ def _init_pythonapi_inc_def_ref(): register_func(c_str("Py_DecRef"), ctypes.pythonapi.Py_DecRef) register_func(c_str("PyGILState_Ensure"), ctypes.pythonapi.PyGILState_Ensure) register_func(c_str("PyGILState_Release"), ctypes.pythonapi.PyGILState_Release) + register_func(c_str("PyErr_CheckSignals"), ctypes.pythonapi.PyErr_CheckSignals) _init_pythonapi_inc_def_ref() From 9b0998463698c34906bcbc431e43adc4eed70759 Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Subramaniam Date: Tue, 23 Jul 2024 04:43:43 +0530 Subject: [PATCH 015/202] [Hexagon] [CMake] Fix v66 build issue (#17169) This patch fixes the issue mentioned in [#17163](https://github.com/apache/tvm/issues/17163) --- apps/hexagon_api/CMakeLists.txt | 7 +++++- cmake/modules/Hexagon.cmake | 44 ++++++++++++++++++++++----------- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/apps/hexagon_api/CMakeLists.txt b/apps/hexagon_api/CMakeLists.txt index 3b5300ac5582..f7144835dbe0 100644 --- a/apps/hexagon_api/CMakeLists.txt +++ b/apps/hexagon_api/CMakeLists.txt @@ -114,6 +114,11 @@ if(DEFINED USE_HEXAGON_GTEST) set(GTEST_FLAG "-DUSE_HEXAGON_GTEST=${USE_HEXAGON_GTEST}") endif() +if(NOT DEFINED USE_HEXAGON_QHL) + # USE_HEXAGON_QHL defaults to ON for rpc runtime if not explicitly set to OFF + set(USE_HEXAGON_QHL ON) +endif() + ExternalProject_Add(hexagon_tvm_runtime_rpc SOURCE_DIR "${TVM_SOURCE_DIR}" BUILD_COMMAND $(MAKE) runtime hexagon_rpc_sim @@ -135,7 +140,7 @@ ExternalProject_Add(hexagon_tvm_runtime_rpc "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" "-DUSE_ALTERNATIVE_LINKER=OFF" "-DUSE_CUSTOM_LOGGING=ON" - "-DUSE_HEXAGON_QHL=ON" + "-DUSE_HEXAGON_QHL=${USE_HEXAGON_QHL}" "-DUSE_RANDOM=ON" "${GTEST_FLAG}" INSTALL_COMMAND "" diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index 21a909e315ac..75b0094ed611 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -134,11 +134,22 @@ else() ) endif() +set(htp_supported_archs "v68" "v69" "v73" "v75") +list(FIND htp_supported_archs "${USE_HEXAGON_ARCH}" supported_arch_index) +if(${supported_arch_index} EQUAL -1) + # Exclude User DMA files when building for archs below v68 + list(REMOVE_ITEM RUNTIME_HEXAGON_SRCS "${TVMRT_SOURCE_DIR}/hexagon/hexagon_user_dma.cc") +endif() + if(BUILD_FOR_HEXAGON) if(DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTEST}) file_glob_append(RUNTIME_HEXAGON_SRCS "${CMAKE_SOURCE_DIR}/tests/cpp-runtime/hexagon/*.cc" ) + if(${supported_arch_index} EQUAL -1) + # Exclude User DMA files when building for archs below v68 + list(REMOVE_ITEM RUNTIME_HEXAGON_SRCS "${TVMRT_SOURCE_DIR}/hexagon/hexagon_user_dma_tests.cc") + endif() endif() get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}" SDK_INCLUDE SDK_INCLUDE_DIRS @@ -176,24 +187,27 @@ if(BUILD_FOR_HEXAGON) endif() - # Hand-written ops - file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/ops/*.cc" - ) + # Exclude HVX implementation files when building for archs below v68 + if(${supported_arch_index} GREATER -1) + # Hand-written ops + file_glob_append(RUNTIME_HEXAGON_SRCS + "${TVMRT_SOURCE_DIR}/hexagon/ops/*.cc" + ) - include_directories( - "${TVMRT_SOURCE_DIR}/hexagon/ops" - ) + include_directories( + "${TVMRT_SOURCE_DIR}/hexagon/ops" + ) - set_source_files_properties( - "${TVMRT_SOURCE_DIR}/hexagon/ops/conv2d_quant_hvx.cc" - PROPERTIES COMPILE_FLAGS "-mhvx" - ) + set_source_files_properties( + "${TVMRT_SOURCE_DIR}/hexagon/ops/conv2d_quant_hvx.cc" + PROPERTIES COMPILE_FLAGS "-mhvx" + ) - set_source_files_properties( - "${TVMRT_SOURCE_DIR}/hexagon/ops/conv2d_fp16_hvx.cc" - PROPERTIES COMPILE_FLAGS "-mhvx" - ) + set_source_files_properties( + "${TVMRT_SOURCE_DIR}/hexagon/ops/conv2d_fp16_hvx.cc" + PROPERTIES COMPILE_FLAGS "-mhvx" + ) + endif() # Include hexagon external library runtime sources if(USE_HEXAGON_EXTERNAL_LIBS) From 432f305ce188f9a679965fb32d1141f92d25b8d0 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 23 Jul 2024 08:13:57 +0900 Subject: [PATCH 016/202] Add `packaging` to `python/gen_requirements.py` (#17188) add packaging as a base dependency --- python/gen_requirements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/gen_requirements.py b/python/gen_requirements.py index 0c8200f60b10..5919d2a9c787 100644 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -68,6 +68,7 @@ "decorator", "ml_dtypes", "numpy", + "packaging", "psutil", "scipy", "tornado", From 162d43a9978f3d31cfd48e3e0ad70ffbba5d22ec Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 23 Jul 2024 13:23:12 +0900 Subject: [PATCH 017/202] [Relax][PyTorch] Add support for torch.einsum (#17186) Add torch.einsum support to Relax PyTorch Frontend. --- .../tvm/relax/frontend/torch/fx_translator.py | 9 ++++ tests/python/relax/test_frontend_from_fx.py | 43 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index f9a5d9c33f02..e6b39c3eee0e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -518,6 +518,14 @@ def _baddbmm(self, node: fx.node.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res + def _einsum(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0])) + return self.block_builder.emit(relax.op.einsum(args[1:], args[0])) + ########## Manipulation ########## def _cat(self, node: fx.node.Node) -> relax.Var: @@ -1482,6 +1490,7 @@ def create_convert_map(self): "max": self._max, "cross_entropy": self._cross_entropy, "scaled_dot_product_attention": self._scaled_dot_product_attention, + "einsum": self._einsum, } def update_convert_map(self, custom_convert_map: dict): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 46c079aa99cc..b4ac3fa60ce9 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -650,6 +650,49 @@ def main( ) +def test_einsum(): + class Einsum1(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.einsum("ii", x) + + class Einsum2(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.einsum("i,j->ij", x, y) + + @tvm.script.ir_module + class Expected1: + @R.function + def main(inp_0: R.Tensor((4, 4), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") + gv: R.Tensor((), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32") + ) -> R.Tensor((5, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 4), dtype="float32") = R.einsum( + (inp_0, inp_1), subscripts="i,j->ij" + ) + gv: R.Tensor((5, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Einsum1(), [([4, 4], "float32")], {}, Expected1) + verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2) + + def test_relu(): class ReLU0(Module): def __init__(self): From e6476847753c80e054719ac47bc2091c888418b6 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 23 Jul 2024 21:39:48 +0900 Subject: [PATCH 018/202] [MetaSchedule] Replace `xgboost.rabit` with `xgboost.collective` because it's deprecated (#17166) * use collective instead of rabit * can work with xgb==1.4.2 in CI --- python/tvm/meta_schedule/cost_model/xgb_model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 6b6b7a2dc1ed..aaee58fc94c8 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -755,7 +755,12 @@ def _fmt_metric(value, show_stdv=True): raise ValueError("wrong metric value", value) import xgboost as xgb - from xgboost import rabit # type: ignore + + # make it compatible with xgboost<1.7 + try: + from xgboost import rabit as collective # type: ignore + except ImportError: + from xgboost import collective # type: ignore try: from xgboost.training import aggcv # type: ignore @@ -841,7 +846,7 @@ def _fmt_metric(value, show_stdv=True): elif epoch - best_iteration >= self.early_stopping_rounds: best_msg = self.state["best_msg"] - if self.verbose_eval and rabit.get_rank() == 0: + if self.verbose_eval and collective.get_rank() == 0: logger.debug("XGB stopped. Best iteration: %s ", best_msg) # instead of raising EarlyStopException, returning True to end the training return True From bbc97c77fbd890361a8705c4450057c5c1bfd0db Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 23 Jul 2024 05:52:57 -0700 Subject: [PATCH 019/202] [Disco] Group-wise operation (#17180) This PR introduces the group attribute into Disco, so that group wise allreduce and allgather is enabled. --- include/tvm/relax/attrs/ccl.h | 18 ++ include/tvm/runtime/disco/builtin.h | 15 +- include/tvm/runtime/disco/disco_worker.h | 8 +- include/tvm/runtime/disco/session.h | 8 +- python/tvm/exec/disco_worker.py | 15 +- python/tvm/relax/frontend/nn/op.py | 13 +- python/tvm/relax/op/ccl/ccl.py | 24 +-- .../tvm/relax/transform/legalize_ops/ccl.py | 10 +- python/tvm/runtime/disco/process_pool.py | 10 +- python/tvm/runtime/disco/session.py | 101 ++++++++--- src/relax/op/ccl/ccl.cc | 22 ++- src/relax/op/ccl/ccl.h | 4 +- src/runtime/disco/builtin.cc | 34 ++-- src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc | 4 +- .../disco/cuda_ipc/custom_allreduce.cc | 4 +- src/runtime/disco/disco_worker_thread.h | 4 +- src/runtime/disco/loader.cc | 8 +- src/runtime/disco/nccl/nccl.cc | 102 ++++++----- src/runtime/disco/nccl/nccl_context.h | 13 +- src/runtime/disco/process_session.cc | 21 ++- src/runtime/disco/threaded_session.cc | 16 +- tests/python/disco/test_callback.py | 11 +- tests/python/disco/test_ccl.py | 168 +++++++++++++++++- tests/python/disco/test_loader.py | 3 +- tests/python/disco/test_session.py | 20 +-- ...ed_transform_lower_global_to_local_view.py | 4 +- .../relax/test_transform_legalize_ops_ccl.py | 18 +- 27 files changed, 491 insertions(+), 187 deletions(-) diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h index 42cec88de673..de043f92be82 100644 --- a/include/tvm/relax/attrs/ccl.h +++ b/include/tvm/relax/attrs/ccl.h @@ -32,14 +32,32 @@ namespace relax { /*! \brief Attributes used in allreduce operators */ struct AllReduceAttrs : public tvm::AttrsNode { String op_type; + bool in_group; TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") { TVM_ATTR_FIELD(op_type).describe( "The type of reduction operation to be applied to the input data. Now only sum is " "supported."); + TVM_ATTR_FIELD(in_group).describe( + "Whether the reduction operation performs in group or globally or in group as default."); } }; // struct AllReduceAttrs +/*! \brief Attributes used in allgather operators */ +struct AllGatherAttrs : public tvm::AttrsNode { + int num_workers; + bool in_group; + + TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") { + TVM_ATTR_FIELD(num_workers) + .describe( + "The number of workers, also the number of parts the given buffer should be chunked " + "into."); + TVM_ATTR_FIELD(in_group).describe( + "Whether the allgather operation performs in group or globally or in group as default."); + } +}; // struct AllGatherAttrs + /*! \brief Attributes used in scatter operators */ struct ScatterCollectiveAttrs : public tvm::AttrsNode { int num_workers; diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index cf9967dbfe76..7d15e35fbdbc 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -75,35 +75,40 @@ TVM_DLL NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device devic * \brief Perform an allreduce operation using the underlying communication library * \param send The array send to perform allreduce on * \param reduce_kind The kind of reduction operation (e.g. sum, avg, min, max) + * \param in_group Whether the allreduce operation performs globally or in group as default. * \param recv The array receives the outcome of allreduce */ -TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv); +TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv); /*! * \brief Perform an allgather operation using the underlying communication library * \param send The array send to perform allgather on + * \param in_group Whether the allgather operation performs globally or in group as default. * \param recv The array receives the outcome of allgather */ -TVM_DLL void AllGather(NDArray send, NDArray recv); +TVM_DLL void AllGather(NDArray send, bool in_group, NDArray recv); /*! * \brief Perform a broadcast operation from worker-0 * \param send The buffer to be broadcasted + * \param in_group Whether the broadcast operation performs globally or in group as default. * \param recv The buffer receives the broadcasted array */ -TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv); +TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv); /*! * \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts. * \param send For worker-0, it must be provided, and otherwise, the buffer must be None. * The buffer will be divided into equal parts and sent to each worker accordingly. + * \param in_group Whether the scatter operation performs globally or in group as default. * \param recv The receiving buffer, which must not be None. */ -TVM_DLL void ScatterFromWorker0(Optional send, NDArray recv); +TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, NDArray recv); /*! * \brief Perform a gather operation to worker-0. * \param send The sending buffer, which must not be None. + * \param in_group Whether the gather operation performs globally or in group as default. * \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The * receiving buffer will be divided into equal parts and receive from each worker accordingly. */ -TVM_DLL void GatherToWorker0(NDArray send, Optional recv); +TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional recv); /*! * \brief Receive a buffer from worker-0. No-op if the current worker is worker-0. * \param buffer The buffer to be received diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h index 14f8f238074f..301b5b8d626b 100644 --- a/include/tvm/runtime/disco/disco_worker.h +++ b/include/tvm/runtime/disco/disco_worker.h @@ -44,14 +44,16 @@ class DiscoWorker { * \brief Construct a worker. * \param worker_id The id of the worker. * \param num_workers The number of the workers. + * \param num_groups The number of the worker groups. * \param worker_zero_data The data shared between worker-0 and the controler. It's a nullptr if * the worker is not worker-0. * \param channel The communication channel between the worker and the controler. */ - explicit DiscoWorker(int worker_id, int num_workers, WorkerZeroData* worker_zero_data, - DiscoChannel* channel) + explicit DiscoWorker(int worker_id, int num_workers, int num_groups, + WorkerZeroData* worker_zero_data, DiscoChannel* channel) : worker_id(worker_id), num_workers(num_workers), + num_groups(num_groups), default_device(Device{DLDeviceType::kDLCPU, 0}), worker_zero_data(worker_zero_data), channel(channel), @@ -68,6 +70,8 @@ class DiscoWorker { int worker_id; /*! \brief Total number of workers */ int num_workers; + /*! \brief Total number of workers */ + int num_groups; /*! \brief The default device to allocate data if not specified */ Device default_device; /*! \brief The name of the underlying collective communication library. */ diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 71fcce75b292..97fa79096d63 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -264,11 +264,13 @@ class Session : public ObjectRef { /*! * \brief Create a session backed by a thread pool of workers * \param num_workers The number of workers. + * \param num_groups The number of worker groups. */ - TVM_DLL static Session ThreadedSession(int num_workers); + TVM_DLL static Session ThreadedSession(int num_workers, int num_groups); /*! * \brief Create a session backed by pipe-based multiprocessing * \param num_workers The number of workers. + * \param num_groups The number of worker groups. * \param process_pool_creator The name of a global function that takes `num_workers` as an input, * and returns a PackedFunc, which takes an integer `worker_id` as the input and returns None. * When `worker-id` is 0, it shuts down the process pool; Otherwise, it retursn a tuple @@ -277,8 +279,8 @@ class Session : public ObjectRef { * \note Worker-0 is always co-located with the controler as a separate thread, and therefore * worker-0 does not exist in the process pool. */ - TVM_DLL static Session ProcessSession(int num_workers, String process_pool_creator, - String entrypoint); + TVM_DLL static Session ProcessSession(int num_workers, int num_groups, + String process_pool_creator, String entrypoint); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); }; diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py index 76ce0ff9936f..b1f1554b56f9 100644 --- a/python/tvm/exec/disco_worker.py +++ b/python/tvm/exec/disco_worker.py @@ -99,22 +99,23 @@ def fget_item(param_name: str, param_index: int) -> NDArray: def main(): """Main worker function""" - if len(sys.argv) != 5: - print("Usage: ") + if len(sys.argv) != 6: + print("Usage: ") return worker_id = int(sys.argv[1]) num_workers = int(sys.argv[2]) + num_groups = int(sys.argv[3]) if sys.platform == "win32": import msvcrt # pylint: disable=import-outside-toplevel,import-error - reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY) - writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) + reader = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) + writer = msvcrt.open_osfhandle(int(sys.argv[5]), os.O_BINARY) else: - reader = int(sys.argv[3]) - writer = int(sys.argv[4]) + reader = int(sys.argv[4]) + writer = int(sys.argv[5]) worker_func = get_global_func("runtime.disco.WorkerProcess") - worker_func(worker_id, num_workers, reader, writer) + worker_func(worker_id, num_workers, num_groups, reader, writer) if __name__ == "__main__": diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index ec072f663cd5..e1ba4483c741 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1671,16 +1671,21 @@ def interpolate( ) -def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"): +def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, name="ccl_allreduce"): """CCL Allreduce operator Parameters ---------- - x : Tensor + x : relax.Expr The input tensor. - op_type: str + + op_type : str The type of reduction operation to be applied to the input data. Now "sum", "prod", "min", "max" and "avg" are supported. + + in_group : bool + Whether the reduction operation performs globally or in group as default. + name : str Name hint for this operation. @@ -1689,7 +1694,7 @@ def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"): result : Tensor The result tensor of allreduce. """ - return wrap_nested(_op.ccl.allreduce(x._expr, op_type), name) + return wrap_nested(_op.ccl.allreduce(x._expr, op_type, in_group), name) def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"): diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py index 21c7946120a7..982c04802156 100644 --- a/python/tvm/relax/op/ccl/ccl.py +++ b/python/tvm/relax/op/ccl/ccl.py @@ -15,25 +15,26 @@ # specific language governing permissions and limitations # under the License. """Relax Collective Communications Library (CCL) operators""" -from typing import Union -from tvm.relax import PrimValue from . import _ffi_api from ...expr import Expr -from ....ir import PrimExpr -def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name +def allreduce(x, op_type: str = "sum", in_group: bool = True): # pylint: disable=invalid-name """Allreduce operator Parameters ---------- x : relax.Expr The input tensor. - op_type: str + + op_type : str The type of reduction operation to be applied to the input data. Now "sum", "prod", "min", "max" and "avg" are supported. + in_group : bool + Whether the reduction operation performs globally or in group as default. + Returns ------- result : relax.Expr @@ -44,10 +45,10 @@ def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name "Allreduce only supports limited reduction operations, " f"including {supported_op_types}, but got {op_type}." ) - return _ffi_api.allreduce(x, op_type) # type: ignore # pylint: disable=no-member + return _ffi_api.allreduce(x, op_type, in_group) # type: ignore # pylint: disable=no-member -def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]): # pylint: disable=invalid-name +def allgather(x, num_workers: int, in_group: bool = True): # pylint: disable=invalid-name """AllGather operator Parameters @@ -55,17 +56,18 @@ def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]): # pylint: disab x : relax.Expr The input tensor. - num_worker : Union[int, PrimExpr, PrimValue] + num_worker : int The number of workers to gather data from. + in_group : bool + Whether the gather operation performs globally or in group as default. + Returns ------- result : relax.Expr The result of allgather. """ - if not isinstance(num_workers, PrimValue): - num_workers = PrimValue(num_workers) - return _ffi_api.allgather(x, num_workers) # type: ignore # pylint: disable=no-member + return _ffi_api.allgather(x, num_workers, in_group) # type: ignore # pylint: disable=no-member def broadcast_from_worker0(x: Expr) -> Expr: diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py index ae0be3c228f5..364dee750e8b 100644 --- a/python/tvm/relax/transform/legalize_ops/ccl.py +++ b/python/tvm/relax/transform/legalize_ops/ccl.py @@ -41,7 +41,7 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr: ) return call_dps_packed( "runtime.disco.allreduce", - [call.args[0], ShapeExpr([op_type_map[op_type_str]])], + [call.args[0], ShapeExpr([op_type_map[op_type_str]]), call.attrs.in_group], out_sinfo=call.args[0].struct_info, ) @@ -57,12 +57,12 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr: arg_shape = arg_sinfo.shape.struct_info for i, shape_value in enumerate(arg_shape.values): if i == 0: - output_shape.append(shape_value * call.args[1].value) + output_shape.append(shape_value * call.attrs.num_workers) else: output_shape.append(shape_value) return call_dps_packed( "runtime.disco.allgather", - call.args[0], + [call.args[0], call.attrs.in_group], out_sinfo=TensorStructInfo( shape=output_shape, dtype=arg_sinfo.dtype, @@ -75,7 +75,7 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr: def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr: return call_dps_packed( "runtime.disco.broadcast_from_worker0", - call.args[0], + [call.args[0], False], out_sinfo=call.args[0].struct_info, ) @@ -116,7 +116,7 @@ def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr: output_shape = output_shape[1:] return call_dps_packed( "runtime.disco.scatter_from_worker0", - transpose_var, + [transpose_var, False], out_sinfo=TensorStructInfo( shape=output_shape, dtype=call.args[0].struct_info.dtype, diff --git a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py index 1ad8659d6088..95969e038e0f 100644 --- a/python/tvm/runtime/disco/process_pool.py +++ b/python/tvm/runtime/disco/process_pool.py @@ -38,6 +38,9 @@ class DiscoPopenWorker: num_workers : int The total number of workers. + num_groups : int + The total number of worker groups. + stdout: Union[None, int, IO[Any]] The standard output streams handler specified for the popen process. @@ -49,12 +52,14 @@ def __init__( # pylint: disable=too-many-arguments self, worker_id: int, num_workers: int, + num_groups: int, entrypoint: str = "tvm.exec.disco_worker", stdout=None, stderr=None, ): self.worker_id = worker_id self.num_workers = num_workers + self.num_groups = num_groups self.entrypoint = entrypoint self._proc = None self._stdout = stdout @@ -118,6 +123,7 @@ def start(self): self.entrypoint, str(self.worker_id), str(self.num_workers), + str(self.num_groups), ] if sys.platform == "win32": import msvcrt # pylint: disable=import-error,import-outside-toplevel @@ -172,9 +178,9 @@ def _kill_child_processes(pid): @register_func("runtime.disco.create_process_pool") -def _create_process_pool(num_workers: int, entrypoint: str): +def _create_process_pool(num_workers: int, num_groups: int, entrypoint: str): """Create a process pool where the workers' are [1, num_workers).""" - pool = [DiscoPopenWorker(i, num_workers, entrypoint) for i in range(1, num_workers)] + pool = [DiscoPopenWorker(i, num_workers, num_groups, entrypoint) for i in range(1, num_workers)] def result_func(worker_id: int): nonlocal pool diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index ddde1bc1f323..38c4f2a2354c 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -66,6 +66,7 @@ def debug_copy_from( ---------- worker_id : int The id of the worker to be copied to. + value : Union[numpy.ndarray, NDArray] The value to be copied. """ @@ -121,6 +122,7 @@ def empty( dtype: str, device: Optional[Device] = None, worker0_only: bool = False, + in_group: bool = True, ) -> DRef: """Create an empty NDArray on all workers and attach them to a DRef. @@ -139,6 +141,11 @@ def empty( If False (default), allocate an array on each worker. If True, only allocate an array on worker0. + in_group: bool + Take effective when `worker0_only` is True. If True (default), + allocate an array on each first worker in each group. If + False, only allocate an array on worker0 globally. + Returns ------- array : DRef @@ -148,7 +155,7 @@ def empty( if device is None: device = Device(device_type=0, device_id=0) func = self._get_cached_method("runtime.disco.empty") - return func(ShapeTuple(shape), dtype, device, worker0_only) + return func(ShapeTuple(shape), dtype, device, worker0_only, in_group) def shutdown(self): """Shut down the Disco session""" @@ -244,6 +251,7 @@ def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: ---------- host_array : numpy.ndarray The array to be copied to worker-0. + remote_array : NDArray The NDArray on worker-0. """ @@ -255,11 +263,9 @@ def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = N Parameters ---------- host_array : NDArray - The array to be copied to worker-0. remote_array : Optiona[DRef] - The destination NDArray on worker-0. Returns @@ -289,6 +295,7 @@ def load_vm_module( ---------- path : str The path to the VM module file. + device : Optional[Device] = None The device to load the VM module to. Default to the default device of each worker. @@ -312,6 +319,7 @@ def init_ccl(self, ccl: str, *device_ids): - nccl - rccl - mpi + *device_ids : int The device IDs to be used by the underlying communication library. """ @@ -319,20 +327,23 @@ def init_ccl(self, ccl: str, *device_ids): _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member self._clear_ipc_memory_pool() - def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef: + def broadcast( + self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True + ) -> DRef: """Broadcast an array to all workers Parameters ---------- src: Union[np.ndarray, NDArray] - The array to be broadcasted. dst: Optional[DRef] - The output array. If None, an array matching the shape and dtype of `src` will be allocated on each worker. + in_group: bool + Whether the broadcast operation performs globally or in group as default. + Returns ------- output_array: DRef @@ -349,38 +360,48 @@ def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) dst = self.empty(src.shape, src.dtype) src_dref = self.copy_to_worker_0(src) - self.broadcast_from_worker0(src_dref, dst) + self.broadcast_from_worker0(src_dref, dst, in_group) return dst - def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef: + def broadcast_from_worker0(self, src: DRef, dst: DRef, in_group: bool = True) -> DRef: """Broadcast an array from worker-0 to all other workers. Parameters ---------- - array : DRef - The array to be broadcasted in-place + src: Union[np.ndarray, NDArray] + The array to be broadcasted. + + dst: Optional[DRef] + The output array. If None, an array matching the shape + and dtype of `src` will be allocated on each worker. + + in_group: bool + Whether the broadcast operation performs globally or in group as default. """ func = self._get_cached_method("runtime.disco.broadcast_from_worker0") - func(src, dst) + func(src, in_group, dst) - def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef: + def scatter( + self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True + ) -> DRef: """Scatter an array across all workers Parameters ---------- src: Union[np.ndarray, NDArray] - The array to be scattered. The first dimension of this array, `src.shape[0]`, must be equal to the number of workers. dst: Optional[DRef] - The output array. If None, an array with compatible shape and the same dtype as `src` will be allocated on each worker. + in_group: bool + Whether the scatter operation performs globally or in group as default. + Returns ------- output_array: DRef @@ -399,41 +420,54 @@ def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) - dst = self.empty(src.shape[1:], src.dtype) src_dref = self.copy_to_worker_0(src) - self.scatter_from_worker0(src_dref, dst) + self.scatter_from_worker0(src_dref, dst, in_group) return dst - def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None: + def scatter_from_worker0(self, from_array: DRef, to_array: DRef, in_group: bool = True) -> None: """Scatter an array from worker-0 to all other workers. Parameters ---------- - from_array : DRef - The array to be scattered from. - to_array : DRef - The array to be scattered to. + src: Union[np.ndarray, NDArray] + The array to be scattered. The first dimension of this + array, `src.shape[0]`, must be equal to the number of + workers. + + dst: Optional[DRef] + The output array. If None, an array with compatible shape + and the same dtype as `src` will be allocated on each + worker. + + in_group: bool + Whether the scatter operation performs globally or in group as default. """ func = self._get_cached_method("runtime.disco.scatter_from_worker0") - func(from_array, to_array) + func(from_array, in_group, to_array) - def gather_to_worker0(self, from_array: DRef, to_array: DRef) -> None: + def gather_to_worker0(self, from_array: DRef, to_array: DRef, in_group: bool = True) -> None: """Gather an array from all other workers to worker-0. Parameters ---------- from_array : DRef The array to be gathered from. + to_array : DRef The array to be gathered to. + + in_group: bool + Whether the gather operation performs globally or in group as default. """ func = self._get_cached_method("runtime.disco.gather_to_worker0") - func(from_array, to_array) + func(from_array, in_group, to_array) def allreduce( self, src: DRef, dst: DRef, op: str = "sum", # pylint: disable=invalid-name + in_group: bool = True, ) -> DRef: """Perform an allreduce operation on an array. @@ -441,6 +475,7 @@ def allreduce( ---------- array : DRef The array to be reduced. + op : str = "sum" The reduce operation to be performed. Available options are: - "sum" @@ -448,17 +483,21 @@ def allreduce( - "min" - "max" - "avg" + + in_group : bool + Whether the reduce operation performs globally or in group as default. """ if op not in REDUCE_OPS: raise ValueError(f"Unsupported reduce op: {op}. Available ops are: {REDUCE_OPS.keys()}") op = ShapeTuple([REDUCE_OPS[op]]) func = self._get_cached_method("runtime.disco.allreduce") - func(src, op, dst) + func(src, op, in_group, dst) def allgather( self, src: DRef, dst: DRef, + in_group: bool = True, ) -> DRef: """Perform an allgather operation on an array. @@ -466,11 +505,15 @@ def allgather( ---------- src : DRef The array to be gathered from. + dst : DRef The array to be gathered to. + + in_group : bool + Whether the reduce operation performs globally or in group as default. """ func = self._get_cached_method("runtime.disco.allgather") - func(src, dst) + func(src, in_group, dst) def _clear_ipc_memory_pool(self): # Clear the IPC memory allocator when the allocator exists. @@ -483,11 +526,12 @@ def _clear_ipc_memory_pool(self): class ThreadedSession(Session): """A Disco session backed by multi-threading.""" - def __init__(self, num_workers: int) -> None: + def __init__(self, num_workers: int, num_groups: int = 1) -> None: """Create a disco session backed by multiple threads in the same process.""" self.__init_handle_by_constructor__( _ffi_api.SessionThreaded, # type: ignore # pylint: disable=no-member num_workers, + num_groups, ) @@ -495,10 +539,13 @@ def __init__(self, num_workers: int) -> None: class ProcessSession(Session): """A Disco session backed by pipe-based multi-processing.""" - def __init__(self, num_workers: int, entrypoint: str = "tvm.exec.disco_worker") -> None: + def __init__( + self, num_workers: int, num_groups: int = 1, entrypoint: str = "tvm.exec.disco_worker" + ) -> None: self.__init_handle_by_constructor__( _ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member num_workers, + num_groups, "runtime.disco.create_process_pool", entrypoint, ) diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index c0fe6f4d88d7..092727cb5115 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -27,9 +27,10 @@ namespace relax { /* relax.ccl.allreduce */ TVM_REGISTER_NODE_TYPE(AllReduceAttrs); -Expr allreduce(Expr x, String op_type) { +Expr allreduce(Expr x, String op_type, bool in_group) { ObjectPtr attrs = make_object(); attrs->op_type = std::move(op_type); + attrs->in_group = std::move(in_group); static const Op& op = Op::Get("relax.ccl.allreduce"); return Call(op, {std::move(x)}, Attrs{attrs}, {}); @@ -51,19 +52,24 @@ TVM_REGISTER_OP("relax.ccl.allreduce") .set_attr("FPurity", Bool(true)); /* relax.ccl.allgather */ -Expr allgather(Expr x, Expr num_workers) { +TVM_REGISTER_NODE_TYPE(AllGatherAttrs); + +Expr allgather(Expr x, int num_workers, bool in_group) { + ObjectPtr attrs = make_object(); + attrs->num_workers = std::move(num_workers); + attrs->in_group = std::move(in_group); + static const Op& op = Op::Get("relax.ccl.allgather"); - return Call(op, {std::move(x), std::move(num_workers)}); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); } TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather); StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { - CHECK_EQ(call->args.size(), 2); - auto input_sinfo = Downcast(call->args[0]->struct_info_); - auto num_workers_sinfo = Downcast(call->args[1]->struct_info_); + TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - auto num_workers = num_workers_sinfo->value; + const auto* attrs = call->attrs.as(); + int num_workers = attrs->num_workers; DataType output_dtype = input_sinfo->dtype; auto input_shape = input_sinfo->GetShape(); @@ -71,7 +77,7 @@ StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { return input_sinfo; } Array output_shape = input_shape.value(); - output_shape.Set(0, floor(output_shape[0] * num_workers.value())); + output_shape.Set(0, floor(output_shape[0] * num_workers)); return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice); } diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h index 3e7f0220c9dc..82ea3935675d 100644 --- a/src/relax/op/ccl/ccl.h +++ b/src/relax/op/ccl/ccl.h @@ -33,10 +33,10 @@ namespace tvm { namespace relax { /*! \brief AllReduce. */ -Expr allreduce(Expr data, String op_type); +Expr allreduce(Expr data, String op_type, bool in_group); /*! \brief AllGather. */ -Expr allgather(Expr data, Expr num_workers); +Expr allgather(Expr data, int num_workers, bool in_group); /*! \brief Broadcast data from worker-0 to all other workers. */ Expr broadcast_from_worker0(Expr data); diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 26d1c22ee975..0cb2ee6f5d6b 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -79,22 +79,24 @@ const PackedFunc& GetCCLFunc(const char* name) { return *pf; } -void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) { - GetCCLFunc("allreduce")(send, static_cast(reduce_kind), recv); +void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) { + GetCCLFunc("allreduce")(send, static_cast(reduce_kind), in_group, recv); } -void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); } +void AllGather(NDArray send, bool in_group, NDArray recv) { + GetCCLFunc("allgather")(send, in_group, recv); +} -TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv) { - GetCCLFunc("broadcast_from_worker0")(send, recv); +TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv) { + GetCCLFunc("broadcast_from_worker0")(send, in_group, recv); } -TVM_DLL void ScatterFromWorker0(Optional send, NDArray recv) { - GetCCLFunc("scatter_from_worker0")(send, recv); +TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { + GetCCLFunc("scatter_from_worker0")(send, in_group, recv); } -void GatherToWorker0(NDArray send, Optional recv) { - GetCCLFunc("gather_to_worker0")(send, recv); +void GatherToWorker0(NDArray send, bool in_group, Optional recv) { + GetCCLFunc("gather_to_worker0")(send, in_group, recv); } void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); } @@ -110,9 +112,13 @@ void SyncWorker() { TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); TVM_REGISTER_GLOBAL("runtime.disco.empty") - .set_body_typed([](ShapeTuple shape, DataType dtype, Device device, - bool worker0_only) -> Optional { - if (worker0_only && WorkerId()) { + .set_body_typed([](ShapeTuple shape, DataType dtype, Device device, bool worker0_only, + bool in_group) -> Optional { + int worker_id = WorkerId(); + int group_size = + DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups; + bool is_worker0 = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); + if (worker0_only && !is_worker0) { return NullOpt; } else { return DiscoEmptyNDArray(shape, dtype, device); @@ -120,10 +126,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.empty") }); TVM_REGISTER_GLOBAL("runtime.disco.allreduce") - .set_body_typed([](NDArray send, ShapeTuple reduce_kind, NDArray recv) { + .set_body_typed([](NDArray send, ShapeTuple reduce_kind, bool in_group, NDArray recv) { int kind = IntegerFromShapeTuple(reduce_kind); CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; - AllReduce(send, static_cast(kind), recv); + AllReduce(send, static_cast(kind), in_group, recv); }); TVM_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather); TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0); diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc index fec5abec86b0..490217d62c79 100644 --- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc @@ -47,8 +47,8 @@ std::vector AllGatherIPCHandles(nccl::CCLThreadLocalContext* CUDA_CALL(cudaMalloc(&d_src, CUDA_IPC_HANDLE_SIZE)); CUDA_CALL(cudaMalloc(&d_dst, CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers)); CUDA_CALL(cudaMemcpy(d_src, &local_handle, CUDA_IPC_HANDLE_SIZE, cudaMemcpyHostToDevice)); - NCCL_CALL( - ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->comm, /*stream=*/nullptr)); + NCCL_CALL(ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->global_comm, + /*stream=*/nullptr)); std::vector serial_handles(CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, 0); CUDA_CALL(cudaMemcpy(serial_handles.data(), d_dst, CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, cudaMemcpyDefault)); diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/disco/cuda_ipc/custom_allreduce.cc index 98fd777b8364..d969005f9476 100644 --- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc @@ -65,6 +65,8 @@ inline bool CanApplyTwoShotAllReduce(int64_t num_elements, DLDataType dtype, int void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { int64_t num_elements = TensorSize(send); nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_groups, 1) + << "Custom AllReduce for multiple group is not yet implemented."; tensorrt_llm::AllReduceStrategyType strategy_ = static_cast(strategy); @@ -79,7 +81,7 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllReduce(send->data, recv->data, num_elements, /*datatype=*/nccl::AsNCCLDataType(DataType(send->dtype)), - /*op=*/ncclSum, ctx->comm, stream)); + /*op=*/ncclSum, ctx->global_comm, stream)); return; } diff --git a/src/runtime/disco/disco_worker_thread.h b/src/runtime/disco/disco_worker_thread.h index 67742cdd0408..8d6b44396f4d 100644 --- a/src/runtime/disco/disco_worker_thread.h +++ b/src/runtime/disco/disco_worker_thread.h @@ -47,12 +47,14 @@ class DiscoWorkerThread { * \brief Construct a worker thread. * \param worker_id The id of the worker. * \param num_workers The total number of workers. + * \param num_groups The total number of worker groups. * \param worker_zero_data_ The data shared between worker-0 and the controler. It's a nullptr if * the worker is not worker-0. * \note This method is implemented in threaded worker, because it depends on creation of a * sub-class of DiscoChannel, DiscoThreadChannel, which is hidden from the public interface. */ - explicit DiscoWorkerThread(int worker_id, int num_workers, WorkerZeroData* worker_zero_data_); + explicit DiscoWorkerThread(int worker_id, int num_workers, int num_groups, + WorkerZeroData* worker_zero_data_); /*! \brief Move constructor. */ explicit DiscoWorkerThread(DiscoWorkerThread&& other) diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index 7a5d97894680..efe42539cb56 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -326,19 +326,19 @@ NDArray ShardLoaderObj::Load(int weight_index) const { for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) { w = this->ApplyShardFunc(shard_func, w); } - ScatterFromWorker0(w, recv); + ScatterFromWorker0(w, /*in_group=*/false, recv); } else { - ScatterFromWorker0(NullOpt, recv); + ScatterFromWorker0(NullOpt, /*in_group=*/false, recv); } return recv; } else { if (worker_id == 0) { NDArray w = LoadDirect(weight_index); - BroadcastFromWorker0(w, w); + BroadcastFromWorker0(w, /*in_group=*/false, w); return w; } else { NDArray w = NDArray::Empty(param->shape, param->dtype, device); - BroadcastFromWorker0(w, w); + BroadcastFromWorker0(w, /*in_group=*/false, w); return w; } } diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index bba42ed3bdfe..2d2c528b5291 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -72,9 +72,12 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { << "ValueError: The length of unique_id must be " << NCCL_UNIQUE_ID_BYTES << ", but got " << unique_id_bytes.size() << "."; - CHECK(!ctx->comm) << "Cannot initialize CCL, " - << "the previous thread-global comm still exists, " - << "and has not been destructed"; + CHECK(!ctx->global_comm) << "Cannot initialize CCL, " + << "the previous thread-global comm still exists, " + << "and has not been destructed"; + CHECK(!ctx->group_comm) << "Cannot initialize CCL, " + << "the previous thread-group comm still exists, " + << "and has not been destructed"; CHECK(!ctx->default_stream) << "Cannot initialize CCL, " << "the previous thread-global stream still exists, " << "and has not been destructed"; @@ -96,34 +99,41 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { // Initialize the communicator ncclUniqueId id; std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES); - NCCL_CALL(ncclCommInitRank(&ctx->comm, worker->num_workers, id, worker->worker_id)); + int group_size = worker->num_workers / worker->num_groups; + NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, worker->worker_id)); + NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size, + worker->worker_id % group_size, &ctx->group_comm, NULL)); } -void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) { +void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); ShapeTuple shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllReduce(send->data, recv->data, numel, /*datatype=*/AsNCCLDataType(DataType(send->dtype)), - /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream)); + /*op=*/AsNCCLRedOp(reduce_kind), + in_group ? ctx->group_comm : ctx->global_comm, stream)); } -void AllGather(NDArray send, NDArray recv) { +void AllGather(NDArray send, bool in_group, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); ShapeTuple shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllGather(send->data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(send->dtype)), ctx->comm, stream)); + /*datatype=*/AsNCCLDataType(DataType(send->dtype)), + in_group ? ctx->group_comm : ctx->global_comm, stream)); } -void BroadcastFromWorker0(Optional send, NDArray recv) { +void BroadcastFromWorker0(Optional send, bool in_group, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + int worker_id = ctx->worker->worker_id; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); const void* send_data = [&]() -> const void* { - int worker_id = ctx->worker->worker_id; - if (worker_id == 0) { + if (is_sender) { CHECK(send.defined()); CHECK(send.value().Shape()->Product() == recv.Shape()->Product()); return send.value()->data; @@ -136,25 +146,28 @@ void BroadcastFromWorker0(Optional send, NDArray recv) { deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclBroadcast(send_data, recv->data, numel, /*datatype=*/AsNCCLDataType(DataType(recv->dtype)), - /*root=*/0, ctx->comm, stream)); + /*root=*/0, in_group ? ctx->group_comm : ctx->global_comm, stream)); } -void ScatterFromWorker0(Optional send, NDArray recv) { +void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; int num_workers = ctx->worker->num_workers; + int group_size = num_workers / ctx->worker->num_groups; + bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); + int num_receiver = in_group ? group_size : num_workers; deviceStream_t stream = ctx->GetDefaultStream(); - if (worker_id == 0) { + if (is_sender) { CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0."; NDArray buffer = send.value(); int64_t numel = buffer.Shape()->Product(); - CHECK_EQ(numel % num_workers, 0) << "ValueError: Scattering evenly requires that the number " - "of elements in the buffer to be " - "divisible by the number of workers, but got numel = " - << numel << " and " << num_workers << " workers."; + CHECK_EQ(numel % num_receiver, 0) << "ValueError: Scattering evenly requires that the number " + "of elements in the buffer to be " + "divisible by the number of workers, but got numel = " + << numel << " and " << num_receiver << " workers."; DataType dtype(buffer->dtype); - int64_t numel_per_shard = numel / num_workers; + int64_t numel_per_shard = numel / num_receiver; int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); CHECK_EQ(numel_per_shard, recv.Shape()->Product()) << "ValueError: The number of elements in buffer `recv` must be the same as each shard " @@ -163,40 +176,45 @@ void ScatterFromWorker0(Optional send, NDArray recv) { << numel << ", but `recv.size` is " << recv.Shape()->Product() << "."; NCCL_CALL(ncclGroupStart()); uint8_t* data = static_cast(buffer->data); - for (int i = 0; i < num_workers; ++i) { - NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, stream)); + for (int i = 0; i < num_receiver; ++i) { + NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, + in_group ? ctx->group_comm : ctx->global_comm, stream)); data += bytes_per_shard; } } else { if (send.defined()) { - LOG(WARNING) << "Buffer `send` must be None when worker_id != 0, but got " - "send = " + LOG(WARNING) << "ValueError: buffer `send` must be None when (worker_id != 0 && !in_group) " + "or (worker_id % group_size != 0 && in_group). However, got send = " << send.get() << ". This will be ignored."; } NCCL_CALL(ncclGroupStart()); } int64_t numel = recv.Shape()->Product(); DataType dtype(recv->dtype); - NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, stream)); + NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, + in_group ? ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } -void GatherToWorker0(NDArray send, Optional recv) { +void GatherToWorker0(NDArray send, bool in_group, Optional recv) { CHECK(send.defined()) << "ValueError: buffer `send` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; int num_workers = ctx->worker->num_workers; + int group_size = num_workers / ctx->worker->num_groups; + bool is_sender = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); + int num_receiver = in_group ? group_size : num_workers; deviceStream_t stream = ctx->GetDefaultStream(); - if (worker_id == 0) { + if (is_sender) { CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0."; NDArray buffer = recv.value(); int64_t numel = buffer.Shape()->Product(); - CHECK_EQ(numel % num_workers, 0) << "ValueError: Gathering evenly requires that the number " - "of elements in the buffer to be " - "divisible by the number of workers, but got numel = " - << numel << " and " << num_workers << " workers."; + CHECK_EQ(numel % num_receiver, 0) << "ValueError: Gathering evenly requires that the number " + "of elements in the buffer to be " + "divisible by the number of workers, but got numel = " + << numel << " and " << num_receiver << " workers."; DataType dtype(buffer->dtype); - int64_t numel_per_shard = numel / num_workers; + int64_t numel_per_shard = numel / num_receiver; int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); CHECK_EQ(numel_per_shard, send.Shape()->Product()) << "ValueError: The number of elements in buffer `send` must be the same as each shard " @@ -205,21 +223,23 @@ void GatherToWorker0(NDArray send, Optional recv) { << numel << ", but `send.size` is " << send.Shape()->Product() << "."; NCCL_CALL(ncclGroupStart()); uint8_t* data = static_cast(buffer->data); - for (int i = 0; i < num_workers; ++i) { - NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, stream)); + for (int i = 0; i < num_receiver; ++i) { + NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, + in_group ? ctx->group_comm : ctx->global_comm, stream)); data += bytes_per_shard; } } else { if (recv.defined()) { - LOG(WARNING) << "ValueError: buffer `recv` must be None when worker_id != 0. However, got " - "recv = " + LOG(WARNING) << "ValueError: buffer `recv` must be None when (worker_id != 0 && !in_group) " + "or (worker_id % group_size != 0 && in_group). However, got recv = " << recv.get() << ". This will be ignored."; } NCCL_CALL(ncclGroupStart()); } int64_t numel = send.Shape()->Product(); DataType dtype(send->dtype); - NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, stream)); + NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, + in_group ? ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } @@ -230,7 +250,7 @@ void RecvFromWorker0(NDArray buffer) { << "ValueError: Worker 0 is not allowed to call RecvFromWorker0."; NCCL_CALL(ncclGroupStart()); NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), 0, - ctx->comm, stream)); + ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } @@ -248,12 +268,14 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_ty TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker") .set_body_typed(InitCCLPerWorker); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce") - .set_body_typed([](NDArray send, int kind, NDArray recv) { + .set_body_typed([](NDArray send, int kind, bool in_group, NDArray recv) { CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; - nccl::AllReduce(send, static_cast(kind), recv); + nccl::AllReduce(send, static_cast(kind), in_group, recv); }); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather") - .set_body_typed([](NDArray send, NDArray recv) { nccl::AllGather(send, recv); }); + .set_body_typed([](NDArray send, bool in_group, NDArray recv) { + nccl::AllGather(send, in_group, recv); + }); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0") .set_body_typed(BroadcastFromWorker0); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0") diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index 3fb281f2cb7c..730479b61ac0 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -121,14 +121,19 @@ struct CCLThreadLocalContext { DiscoWorker* worker = nullptr; int device_id; deviceStream_t default_stream = nullptr; - ncclComm_t comm = nullptr; + ncclComm_t global_comm = nullptr; + ncclComm_t group_comm = nullptr; ~CCLThreadLocalContext() { Clear(); } void Clear() { - if (comm) { - NCCL_CALL(ncclCommDestroy(comm)); - comm = nullptr; + if (group_comm) { + NCCL_CALL(ncclCommDestroy(group_comm)); + group_comm = nullptr; + } + if (global_comm) { + NCCL_CALL(ncclCommDestroy(global_comm)); + global_comm = nullptr; } if (default_stream) { StreamDestroy(default_stream); diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index 179010db8a23..7c8d0796dd81 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -154,9 +154,10 @@ class DiscoProcessChannel final : public DiscoChannel { class ProcessSessionObj final : public BcastSessionObj { public: - explicit ProcessSessionObj(int num_workers, PackedFunc process_pool) + explicit ProcessSessionObj(int num_workers, int num_groups, PackedFunc process_pool) : process_pool_(process_pool), - worker_0_(std::make_unique(0, num_workers, &worker_zero_data_)) { + worker_0_( + std::make_unique(0, num_workers, num_groups, &worker_zero_data_)) { std::vector read_fds; std::vector write_fds; read_fds.reserve(num_workers - 1); @@ -258,18 +259,24 @@ class ProcessSessionObj final : public BcastSessionObj { TVM_REGISTER_OBJECT_TYPE(DiscoDebugObject); TVM_REGISTER_OBJECT_TYPE(ProcessSessionObj); -Session Session::ProcessSession(int num_workers, String process_pool_creator, String entrypoint) { +Session Session::ProcessSession(int num_workers, int num_group, String process_pool_creator, + String entrypoint) { + CHECK_EQ(num_workers % num_group, 0) + << "The number of workers should be divisible by the number of worker group."; const PackedFunc* pf = Registry::Get(process_pool_creator); CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator << " in the registry. Please check if it is registered."; - PackedFunc process_pool = (*pf)(num_workers, entrypoint); - auto n = make_object(num_workers, process_pool); + PackedFunc process_pool = (*pf)(num_workers, num_group, entrypoint); + auto n = make_object(num_workers, num_group, process_pool); return Session(n); } -void WorkerProcess(int worker_id, int num_workers, int64_t read_fd, int64_t write_fd) { +void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_fd, + int64_t write_fd) { + CHECK_EQ(num_workers % num_group, 0) + << "The number of workers should be divisible by the number of worker group."; DiscoProcessChannel channel(read_fd, write_fd); - DiscoWorker worker(worker_id, num_workers, nullptr, &channel); + DiscoWorker worker(worker_id, num_workers, num_group, nullptr, &channel); worker.MainLoop(); } diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index 22f906b809d2..cc9a311a6b3f 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -133,20 +133,20 @@ class DiscoThreadChannel final : public DiscoChannel { DiscoThreadedMessageQueue worker_to_controler_; }; -DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers, +DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers, int num_groups, WorkerZeroData* worker_zero_data_) : channel(std::make_unique()), - worker( - std::make_unique(worker_id, num_workers, worker_zero_data_, channel.get())), + worker(std::make_unique(worker_id, num_workers, num_groups, worker_zero_data_, + channel.get())), thread(std::make_unique([worker = this->worker.get()] { worker->MainLoop(); })) { } class ThreadedSessionObj final : public BcastSessionObj { public: - explicit ThreadedSessionObj(int num_workers) { + explicit ThreadedSessionObj(int num_workers, int num_groups) { for (int i = 0; i < num_workers; ++i) { WorkerZeroData* data = (i == 0) ? &worker_zero_data_ : nullptr; - workers_.emplace_back(i, num_workers, data); + workers_.emplace_back(i, num_workers, num_groups, data); } } @@ -185,8 +185,10 @@ class ThreadedSessionObj final : public BcastSessionObj { TVM_REGISTER_OBJECT_TYPE(ThreadedSessionObj); -Session Session::ThreadedSession(int num_workers) { - ObjectPtr n = make_object(num_workers); +Session Session::ThreadedSession(int num_workers, int num_group) { + CHECK_EQ(num_workers % num_group, 0) + << "The number of workers should be divisible by the number of worker group."; + ObjectPtr n = make_object(num_workers, num_group); return Session(std::move(n)); } diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py index 6e2dc9b7470c..3f8d5e9e525b 100644 --- a/tests/python/disco/test_callback.py +++ b/tests/python/disco/test_callback.py @@ -30,16 +30,17 @@ @tvm.testing.requires_nccl def test_callback(): + """Simulate lazy loading of parameters in a callback + + The output of a lazy parameter loading, which would accept a + callback to load the parameters. + """ + @R.function def transform_params( rank_arg: R.Prim(value="rank"), fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object), ): - """Simulate lazy loading of parameters in a callback - - The output of a lazy parameter loading, which would accept a - callback to load the parameters. - """ rank = T.int64() A = fget_item(R.str("A"), R.prim_value(0)) diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py index 5831f245dfaf..6c63f64554a3 100644 --- a/tests/python/disco/test_ccl.py +++ b/tests/python/disco/test_ccl.py @@ -78,6 +78,42 @@ def test_allreduce(session_kind, ccl): np.testing.assert_equal(result, expected) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_allreduce(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(12, dtype="float32").reshape(3, 4) + array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4) + array_3 = np.arange(30, dtype="float32").reshape(5, 6) + array_4 = np.arange(start=1, stop=-29, step=-1, dtype="float32").reshape(5, 6) + d_array_1 = sess.empty((3, 4), "float32") + d_array_2 = sess.empty((5, 6), "float32") + d_array_1.debug_copy_from(0, array_1) + d_array_1.debug_copy_from(1, array_2) + d_array_2.debug_copy_from(2, array_3) + d_array_2.debug_copy_from(3, array_4) + for op, np_op in [ # pylint: disable=invalid-name + ("sum", np.add), + ("prod", np.multiply), + ("min", np.minimum), + ("max", np.maximum), + ("avg", lambda a, b: (a + b) * 0.5), + ]: + dst_array_1 = sess.empty((3, 4), "float32") + dst_array_2 = sess.empty((5, 6), "float32") + sess.allreduce(d_array_1, dst_array_1, op=op, in_group=True) + sess.allreduce(d_array_2, dst_array_2, op=op, in_group=True) + result_1 = dst_array_1.debug_get_from_remote(0).numpy() + result_2 = dst_array_2.debug_get_from_remote(2).numpy() + expected_1 = np_op(array_1, array_2) + expected_2 = np_op(array_3, array_4) + np.testing.assert_equal(result_1, expected_1) + np.testing.assert_equal(result_2, expected_2) + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) def test_allgather(session_kind, ccl): @@ -101,10 +137,47 @@ def test_allgather(session_kind, ccl): ) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_allgather(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(36, dtype="float32") + array_2 = np.arange(48, dtype="float32") + d_src_1 = sess.empty((3, 3, 2), "float32") + d_dst_1 = sess.empty((3, 4, 3), "float32") + d_src_2 = sess.empty((2, 4, 3), "float32") + d_dst_2 = sess.empty((2, 6, 4), "float32") + d_src_1.debug_copy_from(0, array_1[:18]) + d_src_1.debug_copy_from(1, array_1[18:]) + d_src_2.debug_copy_from(2, array_2[:24]) + d_src_2.debug_copy_from(3, array_2[24:]) + sess.allgather(d_src_1, d_dst_1, in_group=True) + sess.allgather(d_src_2, d_dst_2, in_group=True) + np.testing.assert_equal( + d_dst_1.debug_get_from_remote(0).numpy(), + array_1.reshape(3, 4, 3), + ) + np.testing.assert_equal( + d_dst_1.debug_get_from_remote(1).numpy(), + array_1.reshape(3, 4, 3), + ) + np.testing.assert_equal( + d_dst_2.debug_get_from_remote(2).numpy(), + array_2.reshape(2, 6, 4), + ) + np.testing.assert_equal( + d_dst_2.debug_get_from_remote(3).numpy(), + array_2.reshape(2, 6, 4), + ) + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) @pytest.mark.parametrize("use_explicit_output", [True, False]) -def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output): +def test_broadcast(session_kind, ccl, use_explicit_output): devices = [0, 1] sess = session_kind(num_workers=len(devices)) sess.init_ccl(ccl, *devices) @@ -123,6 +196,29 @@ def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output): np.testing.assert_equal(result, array) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_broadcast(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(12, dtype="float32").reshape(3, 4) + array_2 = np.multiply(array_1, -1) + + src_array = sess.empty((3, 4), "float32", worker0_only=True, in_group=True) + src_array.debug_copy_from(0, array_1) + src_array.debug_copy_from(2, array_2) + dst_array = sess.empty((3, 4), "float32") + sess.broadcast_from_worker0(src_array, dst_array) + + result_1 = dst_array.debug_get_from_remote(1).numpy() + np.testing.assert_equal(result_1, array_1) + + result_3 = dst_array.debug_get_from_remote(3).numpy() + np.testing.assert_equal(result_3, array_2) + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) @pytest.mark.parametrize("use_explicit_output", [True, False]) @@ -156,6 +252,45 @@ def test_scatter(session_kind, ccl, use_explicit_output, capfd): ), "No warning messages should be generated from disco.Session.scatter_from_worker0" +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_scatter(session_kind, ccl, capfd): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(36, dtype="float32").reshape(2, 6, 3) + array_2 = np.multiply(array_1, -1) + + d_src = sess.empty((2, 6, 3), "float32", worker0_only=True, in_group=True) + d_src.debug_copy_from(0, array_1) + d_src.debug_copy_from(2, array_2) + d_dst = sess.empty((6, 3), "float32") + sess.scatter_from_worker0(d_src, d_dst) + + np.testing.assert_equal( + d_dst.debug_get_from_remote(0).numpy(), + array_1[0, :, :], + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(1).numpy(), + array_1[1, :, :], + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(2).numpy(), + array_2[0, :, :], + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(3).numpy(), + array_2[1, :, :], + ) + + captured = capfd.readouterr() + assert ( + not captured.err + ), "No warning messages should be generated from disco.Session.scatter_from_worker0" + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) def test_scatter_with_implicit_reshape(session_kind, ccl, capfd): @@ -225,6 +360,37 @@ def test_gather(session_kind, ccl, capfd): ), "No warning messages should be generated from disco.Session.gather_to_worker0" +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_group_gather(session_kind, ccl, capfd): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(36, dtype="float32") + array_2 = np.multiply(array_1, -1) + d_src = sess.empty((3, 3, 2), "float32") + d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True, in_group=True) + d_src.debug_copy_from(0, array_1[:18]) + d_src.debug_copy_from(1, array_1[18:]) + d_src.debug_copy_from(2, array_2[:18]) + d_src.debug_copy_from(3, array_2[18:]) + sess.gather_to_worker0(d_src, d_dst) + np.testing.assert_equal( + d_dst.debug_get_from_remote(0).numpy(), + array_1.reshape(3, 4, 3), + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(2).numpy(), + array_2.reshape(3, 4, 3), + ) + + captured = capfd.readouterr() + assert ( + not captured.err + ), "No warning messages should be generated from disco.Session.gather_to_worker0" + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py index 502cbe0b811a..b4e2440857e6 100644 --- a/tests/python/disco/test_loader.py +++ b/tests/python/disco/test_loader.py @@ -22,6 +22,7 @@ import numpy as np import tvm +import tvm.testing from tvm import dlight as dl from tvm import relax as rx from tvm._ffi import register_func @@ -246,7 +247,7 @@ class Module: # pylint: disable=too-few-public-methods @R.function def main( loader: R.Object, - ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128), "float32"),): + ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128), "float32")): R.func_attr({"global_symbol": "main"}) with R.dataflow(): lv0: R.Tensor((64, 64), "float32") = R.call_pure_packed( diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index ef8ea2e70a25..837b3a14f271 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -22,13 +22,14 @@ import pytest import tvm +import tvm.testing from tvm import relax as rx from tvm.runtime import ShapeTuple, String from tvm.runtime import disco as di from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T -from tvm.testing import disco as _ +from tvm.exec import disco_worker as _ def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device): @@ -168,14 +169,14 @@ class TestMod: @T.prim_func def t1(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")): for i, j in T.grid(16, 8): - with T.block("transpose"): + with T.block("t1"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] @T.prim_func def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")): for i, j in T.grid(8, 16): - with T.block("transpose"): + with T.block("t2"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] @@ -183,7 +184,7 @@ def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")): def transpose_1( A: R.Tensor((8, 16), dtype="float32") ) -> R.Tensor((16, 8), dtype="float32"): - R.func_attr({"global_symbol": "main"}) + R.func_attr({"global_symbol": "transpose_1"}) cls = TestMod with R.dataflow(): B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32")) @@ -194,7 +195,7 @@ def transpose_1( def transpose_2( A: R.Tensor((16, 8), dtype="float32") ) -> R.Tensor((8, 16), dtype="float32"): - R.func_attr({"global_symbol": "main"}) + R.func_attr({"global_symbol": "transpose_2"}) cls = TestMod with R.dataflow(): B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32")) @@ -228,11 +229,4 @@ def test_num_workers(session_kind, num_workers): if __name__ == "__main__": - test_int(di.ProcessSession) - test_float(di.ProcessSession) - test_string(di.ProcessSession) - test_string_obj(di.ProcessSession) - test_shape_tuple(di.ProcessSession) - test_ndarray(di.ProcessSession) - test_vm_module(di.ProcessSession) - test_vm_multi_func(di.ProcessSession) + tvm.testing.main() diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py index 3a76f535d76b..6ee64a18156d 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py @@ -220,7 +220,7 @@ def foo( out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"), ) lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce( - gv, op_type="sum" + gv, op_type="sum", in_group=False ) return lv3 @@ -1559,7 +1559,7 @@ def foo( out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv43: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.ccl.allreduce( - gv, op_type="sum" + gv, op_type="sum", in_group=False ) lv44: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.dist.call_tir_local_view( cls.add, diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py index 63563ee3c95d..9ea4d21d610d 100644 --- a/tests/python/relax/test_transform_legalize_ops_ccl.py +++ b/tests/python/relax/test_transform_legalize_ops_ccl.py @@ -40,11 +40,11 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0])], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1])], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2])], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3])], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4])], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) return x # fmt: on @@ -66,8 +66,8 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32")) - gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32")) + gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32")) + gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32")) return x # fmt: on @@ -88,7 +88,7 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", x, out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", [x, False], out_sinfo=R.Tensor((10, 10), dtype="float32")) return x # fmt: on @@ -134,7 +134,7 @@ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 5), dtype="flo cls = Expected gv = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((10, 2, 5), dtype="float32")) gv1 = R.call_tir(cls.transpose, (gv,), out_sinfo=R.Tensor((2, 10, 5), dtype="float32")) - gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1,), out_sinfo=R.Tensor((10, 5), dtype="float32")) + gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1, False), out_sinfo=R.Tensor((10, 5), dtype="float32")) return gv0 # fmt: on From 50d1c97dc982c6ddfe089852d1fbbac3ea629851 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Tue, 23 Jul 2024 20:57:53 +0530 Subject: [PATCH 020/202] [DLIGHT][GPU] Add OpenCL dequant matmul schedule (#17187) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [DLIGHT][GPU] Add OpenCL dequant matmul schedule 1. Enhanced the GPU matmul schedule for OpenCL Android and windows backend. 2. It improves the 2X performance gain for Llama-2-7B prefill process Model device Earlier prefill perf Optimized prefill perf Llama-2-7B-chat-hf Snapdragon® 8 Gen 3 27 tok/sec 50 tok/sec * Update matmul.py --- python/tvm/dlight/gpu/matmul.py | 144 +++++++++++++++++-- tests/python/dlight/test_gpu_matmul.py | 192 ++++++++++++++++++++----- 2 files changed, 292 insertions(+), 44 deletions(-) diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index a5759941caf5..25cc649b44dd 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -27,7 +27,7 @@ from tvm.tir.analysis import undefined_vars from tvm.tir.schedule.schedule import BlockRV -from ..base import analysis +from ..base import analysis, BlockInfo, IterInfo from .base import GPUScheduleRule @@ -273,6 +273,32 @@ def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]: ) +def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo: + def _iter_kind(loop: tir.IterVar) -> str: + return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") + + def _is_reduction_block(block: tir.schedule.BlockRV): + for iter_var in sch.get(block).iter_vars: + if _iter_kind(iter_var) == "R": + return True + return False + + return BlockInfo( + name=sch.get(block).name_hint, + iters=[ + IterInfo( + kind=_iter_kind(iter_var), + var=iter_var.var, + dom=iter_var.dom.extent, + loop_rv=loop_rv, + ) + for loop_rv, iter_var in zip(sch.get_loops(block), sch.get(block).iter_vars) + ], + block_rv=block, + reduction_block=_is_reduction_block(block), + ) + + def get_reduction_blocks(sch, blocks) -> bool: # Get the main computation block def is_reduction(block: BlockRV) -> bool: @@ -914,17 +940,19 @@ def get_configs(self, target: Target) -> Config: storage_align=True, inner_x=False, ) - elif target.kind.name == "opencl" and "android" in str(target.host): + elif target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("windows" in str(target.host)) + ): return Matmul.Config( - block_size_x=8, - block_size_y=16, + block_size_x=32, + block_size_y=8, vthread_x=1, vthread_y=1, micro_size_x=8, micro_size_y=2, micro_size_k=16, vector_size=8, - unroll=64, + unroll=4, use_shared=False, storage_align=False, inner_x=True, @@ -941,6 +969,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) + config = self.get_configs(target) root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) @@ -953,9 +982,22 @@ def apply( # pylint: disable=too-many-locals,missing-docstring index_maps = get_index_map(block_stmt) if index_maps is None: return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + main_block_info = get_block_info(sch, main_block) + iter_infos = main_block_info.iters + + # Checks if it's a inner reduction by getting the last matrix's inner Index + def is_inner_reduction(block_stmt, iter_infos): + end_it = block_stmt.reads[-1].region[-1].min + return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == "R" + + if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, iter_infos): + ret = self.sch_outer_reduction(sch, config, main_block, blocks) + if ret is not None: + return ret # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps block = sch.reindex(main_block, ("read", 0)) sch.transform_layout(block, ("write", 0), a_index_map) block = sch.reindex(main_block, ("read", 1)) @@ -994,10 +1036,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring except: # pylint: disable=bare-except pass - # Step 2. Get schedule config. - config = self.get_configs(target) - - # Step 3. Schedule matmul + # Step 2. Schedule matmul y_kernel_size = config.vthread_y * config.block_size_y * config.micro_size_y x_kernel_size = config.vthread_x * config.block_size_x * config.micro_size_x if config.inner_x: @@ -1075,3 +1114,88 @@ def _cooperative_fetch(index, vec_len): sch.decompose_reduction(main_block, ko) return sch + + def sch_outer_reduction( + self, + sch: tir.Schedule, + config: Config, + reduction_block: tir.schedule.BlockRV, + blocks: List[tir.schedule.BlockRV], + ) -> Optional[tir.Schedule]: + reduction_loops = sch.get_loops(reduction_block) + if not len(reduction_loops) == 4: + return None + + mb, ms, n, k = reduction_loops + if not ( + isinstance(sch.get(n).extent, tir.IntImm) + and isinstance(sch.get(mb).extent, tir.IntImm) + and isinstance(sch.get(ms).extent, tir.Var) + ): + return None + + Threads_X, Threads_Y, VecSize, Unroll_M = ( + config.block_size_x, + config.block_size_y, + config.vector_size, + config.unroll, + ) + + is_dequant_block = len(blocks) > 1 + if is_dequant_block: + compute_block, dequant_block, matmul_block = blocks + sch.compute_inline(compute_block) + else: + (matmul_block,) = blocks + + m = sch.fuse(mb, ms) + + sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * VecSize, 1]) + + rmat_block, wmat_block = ( + sch.get_producers(matmul_block)[0], + sch.get_consumers(matmul_block)[0], + ) + mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M]) + no, ni, nv = sch.split(n, [None, Threads_X, VecSize]) + k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 8]) + sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv) + + sch.compute_at(rmat_block, k0) + if is_dequant_block: + sch.compute_at(dequant_block, k3) + sch.reverse_compute_at(wmat_block, mi) + sch.set_scope(rmat_block, 0, "shared") + sch.set_scope(matmul_block, 0, "local") + if is_dequant_block: + sch.set_scope(dequant_block, 0, "local") + + sch.bind(mo, "blockIdx.y") + sch.bind(no, "blockIdx.x") + sch.bind(mi, "threadIdx.y") + sch.bind(ni, "threadIdx.x") + sch.vectorize(sch.get_loops(matmul_block)[-1]) + if is_dequant_block: + sch.vectorize(sch.get_loops(dequant_block)[-1]) + + # Co-operative Memory Fetch + ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize]) + sch.bind(ro, "threadIdx.x") + sch.vectorize(rv) + + wv = sch.get_loops(wmat_block)[-1] + sch.vectorize(wv) + + # Scale and Quant Cache + if is_dequant_block: + qb = sch.cache_read(dequant_block, 0, "local") + sb = sch.cache_read(dequant_block, 1, "local") + sch.compute_at(sb, k1) + sch.compute_at(qb, k2) + sch.set_scope(sb, 0, "local") + sch.set_scope(qb, 0, "local") + sch.vectorize(sch.get_loops(qb)[-1]) + sch.vectorize(sch.get_loops(sb)[-1]) + + sch.decompose_reduction(matmul_block, k0) + return sch diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index ca32c286abfe..4cef7f1c27c3 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -634,42 +634,166 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) # with T.block("root"): - matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local") - for ax0_ax1_0_fused in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.y"): - for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.x"): - for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.y"): - for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for ax1_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): - for ax1_3_init, ax2_3_0_init in T.grid(T.int64(2), T.int64(1)): - for ax2_3_1_init in T.vectorized(T.int64(8)): - with T.block("matmul_init"): + inp0_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared") + matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local") + for i2_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"): + for i0_i1_fused_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i0_i1_fused_2_init in range(T.int64(4)): + for i2_2_init in T.vectorized(T.int64(8)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2_init) + v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init) + T.reads() + T.writes(matmul_pad_local[v_i0, v_i1, v_i2]) + matmul_pad_local[v_i0, v_i1, v_i2] = T.float32(0) + for k_0 in range(T.int64(16)): + for ax0 in range(T.int64(4)): + for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax1_1 in T.vectorized(T.int64(8)): + with T.block("inp0_pad"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3_init) - v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0_init * T.int64(8) + ax2_3_1_init) - T.reads() - T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2]) - matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0) - for ax3_0, ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(256), T.int64(16), T.int64(2), T.int64(1)): - for ax2_3_1 in T.vectorized(T.int64(8)): - with T.block("matmul_update"): + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1) + T.reads(inp0[v0, v1, v2]) + T.writes(inp0_pad_shared[v0, v1, v2]) + inp0_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0)) + for k_1, k_2, k_3, i0_i1_fused_2 in T.grid(T.int64(8), T.int64(4), T.int64(8), T.int64(4)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2) + v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) + T.reads(matmul_pad_local[v_i0, v_i1, v_i2], inp0_pad_shared[v_i0, v_i1, v_k], inp1[v_k, v_i2]) + T.writes(matmul_pad_local[v_i0, v_i1, v_i2]) + matmul_pad_local[v_i0, v_i1, v_i2] = matmul_pad_local[v_i0, v_i1, v_i2] + inp0_pad_shared[v_i0, v_i1, v_k] * inp1[v_k, v_i2] + for ax0 in range(T.int64(4)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("matmul_pad"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) + v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) + T.where((i0_i1_fused_0 - (m + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < m) + T.reads(matmul_pad_local[v0, v1, v2]) + T.writes(matmul[v0, v1, v2]) + matmul[v0, v1, v2] = matmul_pad_local[v0, v1, v2] + + +class TestFusedDequantMatmulAndroid(AndroidBeforeAfter): + # fmt: off + @T.prim_func + def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + seq_len = T.int64() + rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16") + matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") + # with T.block("root"): + compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") + dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") + for i0, i1 in T.grid(T.int64(4096), T.int64(12288)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(lv840[v_i0 // T.int64(8), v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) + for i0, i1 in T.grid(T.int64(4096), T.int64(12288)): + with T.block("dequantize"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1]) + T.writes(dequantize_intermediate_intermediate[v_i0, v_i1]) + dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1] + for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(rms_norm260[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2]) + T.writes(matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2] + + @T.prim_func + def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle): + T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16") + matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") + # with T.block("root"): + dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local") + rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared") + matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", scope="local") + lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") + lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") + for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"): + for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for i0_i1_fused_2_init in range(T.int64(4)): + for i2_2_init in T.vectorized(T.int64(8)): + with T.block("matmul_init"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2_init) + v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init) + T.reads() + T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float16(0) + for k_0 in range(T.int64(16)): + for ax0 in range(T.int64(4)): + for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax1_1 in T.vectorized(T.int64(8)): + with T.block("rms_norm260_pad"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3) - v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0 * T.int64(8) + ax2_3_1) - v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1) - T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0[T.int64(0), v1, v3], inp1[v3, v2]) - T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2]) - matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + T.if_then_else(v1 < m, inp0[T.int64(0), v1, v3], T.float32(0)) * inp1[v3, v2] - for ax0, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(2), T.int64(1)): - for ax2_1_1 in T.vectorized(T.int64(8)): - with T.block("matmul_reindex_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1) - v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1) - T.where(ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1 < m) - T.reads(matmul_reindex_pad_local[v0, v1, v2]) - T.writes(matmul[T.int64(0), v1, v2]) - matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2] + v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1) + T.reads(rms_norm260[v0, v1, v2]) + T.writes(rms_norm260_pad_shared[v0, v1, v2]) + rms_norm260_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0)) + for k_1 in range(T.int64(8)): + for ax0 in T.vectorized(T.int64(8)): + with T.block("lv841_local"): + v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1) + v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + T.reads(lv841[v0, v1]) + T.writes(lv841_local[v0, v1]) + lv841_local[v0, v1] = lv841[v0, v1] + for k_2 in range(T.int64(4)): + for ax0 in T.vectorized(T.int64(8)): + with T.block("lv840_local"): + v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) + v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + T.reads(lv840[v0, v1]) + T.writes(lv840_local[v0, v1]) + lv840_local[v0, v1] = lv840[v0, v1] + for k_3 in range(T.int64(8)): + for ax0 in T.vectorized(T.int64(8)): + with T.block("dequantize"): + v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) + v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + T.reads(lv840_local[v_i0 // T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1]) + T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1]) + dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1] + for i0_i1_fused_2 in range(T.int64(4)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2) + v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) + T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm260_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2]) + T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2] + for ax0 in range(T.int64(4)): + for ax1 in T.vectorized(T.int64(8)): + with T.block("matmul_intermediate_pad"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) + v2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) + T.where((i0_i1_fused_0 - (seq_len + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len) + T.reads(matmul_intermediate_pad_local[v0, v1, v2]) + T.writes(matmul_intermediate[v0, v1, v2]) + matmul_intermediate[v0, v1, v2] = matmul_intermediate_pad_local[v0, v1, v2] # fmt: on From 7c9969bbdfc7f032f270f9f75eeb53bf6e78ff7b Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 24 Jul 2024 00:33:06 +0900 Subject: [PATCH 021/202] Remove and replace deprecated `distutils.util.strtobool()` (#17185) remove and replace deprecated distutils.util.strtobool --- python/tvm/auto_scheduler/testing/tune_onnx.py | 2 +- python/tvm/auto_scheduler/testing/tune_relay.py | 2 +- python/tvm/auto_scheduler/testing/tune_te.py | 2 +- python/tvm/autotvm/testing/tune_relay.py | 2 +- python/tvm/meta_schedule/testing/tune_onnx.py | 2 +- python/tvm/meta_schedule/testing/tune_relay.py | 2 +- python/tvm/meta_schedule/testing/tune_te.py | 2 +- .../meta_schedule/testing/validate_database.py | 2 +- python/tvm/testing/utils.py | 15 +++++++++++++++ 9 files changed, 23 insertions(+), 8 deletions(-) diff --git a/python/tvm/auto_scheduler/testing/tune_onnx.py b/python/tvm/auto_scheduler/testing/tune_onnx.py index a3299c05bb82..334b5d6726b7 100644 --- a/python/tvm/auto_scheduler/testing/tune_onnx.py +++ b/python/tvm/auto_scheduler/testing/tune_onnx.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -from distutils.util import strtobool import argparse import json import os @@ -30,6 +29,7 @@ from tvm.meta_schedule.utils import cpu_count from tvm.relay.frontend import from_onnx from tvm.support import describe +from tvm.testing.utils import strtobool def _parse_args(): diff --git a/python/tvm/auto_scheduler/testing/tune_relay.py b/python/tvm/auto_scheduler/testing/tune_relay.py index 9773fbbc65ad..babec2cf50c4 100644 --- a/python/tvm/auto_scheduler/testing/tune_relay.py +++ b/python/tvm/auto_scheduler/testing/tune_relay.py @@ -18,7 +18,6 @@ import argparse import json import os -from distutils.util import strtobool import tvm from tvm import auto_scheduler @@ -29,6 +28,7 @@ from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data from tvm.meta_schedule.utils import cpu_count from tvm.support import describe +from tvm.testing.utils import strtobool def _parse_args(): diff --git a/python/tvm/auto_scheduler/testing/tune_te.py b/python/tvm/auto_scheduler/testing/tune_te.py index da3584512dd0..9452d88a4e65 100644 --- a/python/tvm/auto_scheduler/testing/tune_te.py +++ b/python/tvm/auto_scheduler/testing/tune_te.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -from distutils.util import strtobool import argparse import os @@ -25,6 +24,7 @@ from tvm.meta_schedule.testing.te_workload import CONFIGS from tvm.meta_schedule.utils import cpu_count from tvm.support import describe +from tvm.testing.utils import strtobool def _parse_args(): diff --git a/python/tvm/autotvm/testing/tune_relay.py b/python/tvm/autotvm/testing/tune_relay.py index 96e42fbea090..916b2a800b2d 100644 --- a/python/tvm/autotvm/testing/tune_relay.py +++ b/python/tvm/autotvm/testing/tune_relay.py @@ -19,7 +19,6 @@ import json import os import warnings -from distutils.util import strtobool import tvm from tvm import autotvm @@ -31,6 +30,7 @@ from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data from tvm.support import describe +from tvm.testing.utils import strtobool def _parse_args(): diff --git a/python/tvm/meta_schedule/testing/tune_onnx.py b/python/tvm/meta_schedule/testing/tune_onnx.py index a7c177afdca4..2100f0e7c973 100644 --- a/python/tvm/meta_schedule/testing/tune_onnx.py +++ b/python/tvm/meta_schedule/testing/tune_onnx.py @@ -18,7 +18,6 @@ import argparse import json import logging -from distutils.util import strtobool import onnx # type: ignore import tvm @@ -26,6 +25,7 @@ from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.relay.frontend import from_onnx from tvm.support import describe +from tvm.testing.utils import strtobool from .tune_utils import create_timer, generate_input_data diff --git a/python/tvm/meta_schedule/testing/tune_relay.py b/python/tvm/meta_schedule/testing/tune_relay.py index de1668c1dd16..98eddf793fce 100644 --- a/python/tvm/meta_schedule/testing/tune_relay.py +++ b/python/tvm/meta_schedule/testing/tune_relay.py @@ -18,7 +18,6 @@ import argparse import json import logging -from distutils.util import strtobool from typing import Dict import numpy as np # type: ignore @@ -28,6 +27,7 @@ from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data from tvm.support import describe +from tvm.testing.utils import strtobool def _parse_args(): diff --git a/python/tvm/meta_schedule/testing/tune_te.py b/python/tvm/meta_schedule/testing/tune_te.py index 4bbfd8b1517e..de80d7108d7f 100644 --- a/python/tvm/meta_schedule/testing/tune_te.py +++ b/python/tvm/meta_schedule/testing/tune_te.py @@ -17,7 +17,6 @@ # pylint: disable=missing-docstring import argparse import logging -from distutils.util import strtobool from typing import Optional import tvm @@ -25,6 +24,7 @@ from tvm import tir from tvm.meta_schedule.testing.te_workload import create_te_workload from tvm.support import describe +from tvm.testing.utils import strtobool def _parse_args(): diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index a5981a78d645..a790bb49f73e 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -20,7 +20,6 @@ import warnings import itertools from statistics import mean -from distutils.util import strtobool from typing import Callable, Tuple, Union, List, Any import numpy as np # type: ignore @@ -35,6 +34,7 @@ from tvm.meta_schedule.utils import remove_build_dir from tvm.meta_schedule.testing.tune_utils import generate_input_data from tvm.tir.tensor_intrin import * # type: ignore # pylint: disable=wildcard-import,unused-wildcard-import +from tvm.testing.utils import strtobool DELIMITOR = "\n" + "-" * 30 + "\n" diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 8fd64d8ab749..64eaccb410c8 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1913,6 +1913,21 @@ def skip_parameterizations(*skip_params, reason): return _mark_parameterizations(*skip_params, marker_fn=pytest.skip, reason=reason) +def strtobool(val): + """Convert a string representation of truth to true (1) or false (0). + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return 1 + elif val in ("n", "no", "f", "false", "off", "0"): + return 0 + else: + raise ValueError(f"invalid truth value {val!r}") + + def main(): test_file = inspect.getsourcefile(sys._getframe(1)) sys.exit(pytest.main([test_file] + sys.argv[1:])) From 89b91e2b1195b53bf7e1f6c250bc9a1247367d13 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 23 Jul 2024 21:13:41 -0700 Subject: [PATCH 022/202] [KVCache] Partial layers support (#17192) This PR updates the KVCache implementation, to support partial layers. --- include/tvm/runtime/disco/disco_worker.h | 15 ++++ src/runtime/disco/disco_worker.cc | 9 -- src/runtime/relax_vm/paged_kv_cache.cc | 82 +++++++++++++------ ...tin_paged_attention_kv_cache_flashinfer.py | 2 +- ...me_builtin_paged_attention_kv_cache_tir.py | 2 +- 5 files changed, 73 insertions(+), 37 deletions(-) diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h index 301b5b8d626b..13f94802c886 100644 --- a/include/tvm/runtime/disco/disco_worker.h +++ b/include/tvm/runtime/disco/disco_worker.h @@ -93,6 +93,21 @@ class DiscoWorker { struct Impl; friend struct DiscoWorker::Impl; }; +/*! + * \brief A threadlocal wrapper of DiscoWorker. + */ +struct ThreadLocalDiscoWorker { + /*! \brief The Disco worker */ + DiscoWorker* worker; + + /*! + * \brief Get the threadlocal Disco worker. + */ + static ThreadLocalDiscoWorker* Get() { + thread_local static ThreadLocalDiscoWorker worker; + return &worker; + } +}; } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc index b281a3aca7da..5e6f401054ea 100644 --- a/src/runtime/disco/disco_worker.cc +++ b/src/runtime/disco/disco_worker.cc @@ -28,15 +28,6 @@ namespace tvm { namespace runtime { -struct ThreadLocalDiscoWorker { - DiscoWorker* worker; - - static ThreadLocalDiscoWorker* Get() { - thread_local static ThreadLocalDiscoWorker worker; - return &worker; - } -}; - TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() { DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker; CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread"; diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index ec1cc3593a53..2fb8a72f4279 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -21,6 +21,7 @@ * \brief Runtime paged KV cache object for language models. */ #include +#include #include #include #include @@ -825,6 +826,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const int64_t page_size_; /*! \brief The number of layers in the model. */ const int64_t num_layers_; + /*! \brief The beginning layer id offset. */ + const int64_t layer_id_begin_offset_; /*! \brief The number of query/output heads in the model. */ const int64_t num_qo_heads_; /*! \brief The number of key/value heads in the model. */ @@ -981,14 +984,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { public: /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */ explicit PagedAttentionKVCacheObj( - int64_t page_size, // - int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, - int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, - bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta, - DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy, - PackedFunc f_attention_prefill, PackedFunc f_attention_decode, - PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, - PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask, + int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, // + int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs, + int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType dtype, Device device, + PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill, + PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window, + PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged, + PackedFunc f_attention_prefill_with_tree_mask, Optional f_attention_prefill_ragged_begin_forward, Optional f_attention_prefill_ragged_end_forward, Optional f_attention_prefill_begin_forward, @@ -998,6 +1001,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), + layer_id_begin_offset_(layer_id_begin_offset), num_qo_heads_(num_qo_heads), num_kv_heads_(num_kv_heads), head_dim_(head_dim), @@ -1672,7 +1676,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, NDArray o_data, double attn_score_scaling_factor) final { // Part 1. Shape and dtype check. - NDArray pages = pages_[layer_id]; + int64_t local_layer_id = layer_id - layer_id_begin_offset_; + CHECK_GE(local_layer_id, 0); + CHECK_LT(local_layer_id, num_layers_); + NDArray pages = pages_[local_layer_id]; CHECK(qkv_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); @@ -1713,13 +1720,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set. if (append_before_attn_) { - f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_); + f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_); } // Part 4: perform attention AttentionInternal(layer_id, q_data, k_data, v_data, o_data, attn_score_scaling_factor); // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not set. if (!append_before_attn_) { - f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_); + f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_); } } @@ -2238,6 +2245,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { */ void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, NDArray output, double attn_score_scaling_factor) { + int64_t local_layer_id = layer_id - layer_id_begin_offset_; + CHECK_GE(local_layer_id, 0); + CHECK_LT(local_layer_id, num_layers_); PackedFunc f_prefill = !support_sliding_window_ ? f_attention_prefill_ : f_attention_prefill_sliding_window_; PackedFunc f_decode = @@ -2245,7 +2255,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; if (append_before_attn_) { f_decode( - /*depth=*/0, q_data, pages_[layer_id], page_indptr_on_depths_view_[0], + /*depth=*/0, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[0], page_indices_on_depths_view_[0], length_info_on_depths_view_[0], k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, @@ -2280,7 +2290,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } if (use_decode_kernel_[d]) { // Use decode kernel for depth d - f_decode(/*depth=*/d, q_data, pages_[layer_id], page_indptr_on_depths_view_[d], + f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_, temp_attn_scores_view_, @@ -2289,7 +2299,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } else { // Use prefill kernel for depth d f_prefill( - /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[layer_id], + /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_, temp_attn_scores_view_, @@ -2436,7 +2446,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; - int64_t num_layers = args[1]; + ShapeTuple layer_indptr_tuple = args[1]; + int num_groups = 1; + int group_id = 0; + if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) { + // In the Disco worker thread + num_groups = disco_worker->num_groups; + group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups); + } + CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); + int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; + int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; int64_t num_qo_heads = args[2]; int64_t num_kv_heads = args[3]; int64_t head_dim = args[4]; @@ -2482,11 +2502,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") num_total_pages += reserved_num_seqs * 2; } ObjectPtr n = make_object( - page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, - num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), - rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append), - std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode), - std::move(f_attention_prefill_sliding_window), + page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, + reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, + RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device, + std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill), + std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), std::move(f_attention_prefill_ragged_begin_forward), @@ -2503,7 +2523,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; - int64_t num_layers = args[1]; + ShapeTuple layer_indptr_tuple = args[1]; + int num_groups = 1; + int group_id = 0; + if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) { + // In the Disco worker thread + num_groups = disco_worker->num_groups; + group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups); + } + CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); + int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; + int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; int64_t num_qo_heads = args[2]; int64_t num_kv_heads = args[3]; int64_t head_dim = args[4]; @@ -2543,11 +2573,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") num_total_pages += reserved_num_seqs * 2; } ObjectPtr n = make_object( - page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, - num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), - rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append), - std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode), - std::move(f_attention_prefill_sliding_window), + page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, + reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, + RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device, + std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill), + std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), // NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index 048cf498067b..bade04a7d753 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -354,7 +354,7 @@ def create_kv_cache(rope_mode): support_sliding_window, ] ), - num_layers, + tvm.runtime.ShapeTuple([0, num_layers]), num_qo_heads, num_kv_heads, head_dim, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 34680160c8de..9192bb901ff0 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -153,7 +153,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): int(support_sliding_window), ] ), - num_layers, + tvm.runtime.ShapeTuple([0, num_layers]), num_qo_heads, num_kv_heads, head_dim, From 9a07870b2e6480a533dbebe8d10e945fc173cf59 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 24 Jul 2024 11:41:12 +0530 Subject: [PATCH 023/202] [CLML][CI] Fix for few clml regression issues (#17117) * Few regresion fixes * dummy commit * Update clml.py * Update task_python_adreno.sh * Update task_python_adreno.sh * dummy commit --------- Co-authored-by: Krishna Raju Vegiraju --- python/tvm/relay/op/contrib/clml.py | 8 ++++---- tests/scripts/setup-adreno-env.sh | 1 + tests/scripts/task_python_adreno.sh | 3 +-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py index 22a7aae2b165..dace7aaab913 100644 --- a/python/tvm/relay/op/contrib/clml.py +++ b/python/tvm/relay/op/contrib/clml.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument, pointless-exception-statement +# pylint: disable=invalid-name, unused-argument, pointless-exception-statement. """CLML Library supported operators.""" import json from string import Template @@ -166,7 +166,7 @@ def partition_for_clml(mod, params=None, **opts): transform.FoldConstant(), OptimizeBatchnormPass(), transform.MergeComposite(clml_pattern_table()), - transform.AnnotateTarget("clml", False), + transform.AnnotateTarget("clml"), transform.MergeCompilerRegions(), transform.PartitionGraph(), ] @@ -518,7 +518,7 @@ def check_dense1d_op(extract): return False if not (call.op.name in ["nn.bias_add", "add"] and call.args[0].op.name == "nn.dense"): return False - return check_default_op(call) + return True def check_dense2d_op(extract): call = extract @@ -564,7 +564,7 @@ def check_depth_to_space(extract): ("clml.dense2d", dense2d_pattern(), check_dense2d_op), ("clml.pad", pad_pattern(), check_pad_op), ("clml.concat", concat_pattern(), check_concat_op), - ("clml.batch_norm", batch_norm_pattern(), check_default_op), + ("clml.batch_norm", batch_norm_pattern()), ("clml.add", is_op("add")(wildcard(), wildcard()), check_binary_op), ("clml.subtract", is_op("subtract")(wildcard(), wildcard()), check_binary_op), ("clml.multiply", is_op("multiply")(wildcard(), wildcard()), check_binary_op), diff --git a/tests/scripts/setup-adreno-env.sh b/tests/scripts/setup-adreno-env.sh index d2c776412e5f..cfe174214c72 100755 --- a/tests/scripts/setup-adreno-env.sh +++ b/tests/scripts/setup-adreno-env.sh @@ -80,6 +80,7 @@ function def_environment() { export RPC_DEVICE_KEY="android" export RPC_TARGET="adreno" export TVM_NDK_CC="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" + # Compiler definition for c-runtime while empty mod (llvm -mtriple ineffective here). export CXX="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" } diff --git a/tests/scripts/task_python_adreno.sh b/tests/scripts/task_python_adreno.sh index 18e0feb815d1..b889fd64632d 100755 --- a/tests/scripts/task_python_adreno.sh +++ b/tests/scripts/task_python_adreno.sh @@ -31,7 +31,6 @@ export TVM_TRACKER_PORT=$(((RANDOM % 100) + 9100)) export RPC_DEVICE_KEY="android" export RPC_TARGET="adreno" export TVM_NDK_CC="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" -export CXX="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" env PYTHONPATH=python python3 -m tvm.exec.rpc_tracker --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" & TRACKER_PID=$! @@ -79,7 +78,7 @@ CLML_TESTS=$(./ci/scripts/jenkins/pytest_ids.py --folder tests/python/contrib/te i=0 for node_id in $CLML_TESTS; do echo "$node_id" - run_pytest ctypes "$TVM_INTEGRATION_TESTSUITE_NAME-openclml-$i" "$node_id" --reruns=0 + CXX=${TVM_NDK_CC} run_pytest ctypes "$TVM_INTEGRATION_TESTSUITE_NAME-openclml-$i" "$node_id" --reruns=0 i=$((i+1)) done From ae1be53d6dc08ad8a95ddf6af022880e836e8704 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 24 Jul 2024 08:03:21 -0400 Subject: [PATCH 024/202] [Disco] Cross-group and p2p send/receive primitives (#17191) This PR introduces the disco CCL primitives for cross-group and p2p communication. Specifically, we introduce the send/receive primitives for one group to send a buffer to its next group, where every worker in the first group sends the buffer to the corresponding worker in the second group. The p2p communication refer to the send/receive operations to/from a target global worker. --- include/tvm/runtime/disco/builtin.h | 24 ++++++++ python/tvm/relax/frontend/nn/core.py | 6 +- src/runtime/disco/builtin.cc | 16 ++++++ src/runtime/disco/nccl/nccl.cc | 86 ++++++++++++++++++++++++++++ tests/python/disco/test_ccl.py | 40 ++++++++++++- 5 files changed, 168 insertions(+), 4 deletions(-) diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index 7d15e35fbdbc..4453d9737f89 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -114,6 +114,30 @@ TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional recv * \param buffer The buffer to be received */ TVM_DLL void RecvFromWorker0(NDArray buffer); +/*! + * \brief Send a buffer to the corresponding worker in the next group. + * An error is thrown if the worker is already in the last group. + * \param buffer The sending buffer. + */ +TVM_DLL void SendToNextGroup(NDArray buffer); +/*! + * \brief Receive a buffer from the corresponding worker in the previous group. + * An error is thrown if the worker is already in the first group. + * \param buffer The receiving buffer. + */ +TVM_DLL void RecvFromPrevGroup(NDArray buffer); +/*! + * \brief Send a buffer to the target receiver worker (globally across all groups). + * \param buffer The sending buffer. + * \param receiver_id The global receiver worker id. + */ +TVM_DLL void SendToWorker(NDArray buffer, int receiver_id); +/*! + * \brief Receive a buffer from the target sender worker (globally across all groups). + * \param buffer The receiving buffer. + * \param sender_id The global sender worker id. + */ +TVM_DLL void RecvFromWorker(NDArray buffer, int sender_id); /*! \brief Get the local worker id */ TVM_DLL int WorkerId(); /*! diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 46e016a242ea..3511c38a2b7c 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -549,16 +549,16 @@ def __init__(self, modules: List[Module]): def __iter__(self): return iter(self.modules) - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Module: return self.modules[idx] - def __setitem__(self, idx, module): + def __setitem__(self, idx: int, module: Module) -> None: self.modules[idx] = module def __len__(self): return len(self.modules) - def append(self, module): + def append(self, module: Module): """Add a module to the end of the ModuleList""" self.modules.append(module) diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 0cb2ee6f5d6b..760a330a7a8e 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -101,6 +101,18 @@ void GatherToWorker0(NDArray send, bool in_group, Optional recv) { void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); } +void SendToNextGroup(NDArray buffer) { GetCCLFunc("send_to_next_group")(buffer); } + +void RecvFromPrevGroup(NDArray buffer) { GetCCLFunc("recv_from_prev_group")(buffer); } + +void SendToWorker(NDArray buffer, int receiver_id) { + GetCCLFunc("send_to_worker")(buffer, receiver_id); +} + +void RecvFromWorker(NDArray buffer, int sender_id) { + GetCCLFunc("recv_from_worker")(buffer, sender_id); +} + int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; } void SyncWorker() { @@ -136,6 +148,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(Broad TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0); TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0); TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0); +TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup); +TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup); +TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker); +TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker); TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple { return ShapeTuple({WorkerId()}); }); diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 2d2c528b5291..35e8fd06b309 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -254,6 +254,57 @@ void RecvFromWorker0(NDArray buffer) { NCCL_CALL(ncclGroupEnd()); } +void SendToNextGroup(NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + deviceStream_t stream = ctx->GetDefaultStream(); + int worker_id = ctx->worker->worker_id; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + int receiver_id = worker_id + group_size; + CHECK_LT(receiver_id, ctx->worker->num_workers) + << "The current group is already the last group and there is no such a next group."; + NCCL_CALL(ncclGroupStart()); + NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + receiver_id, ctx->global_comm, stream)); + NCCL_CALL(ncclGroupEnd()); +} + +void RecvFromPrevGroup(NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + deviceStream_t stream = ctx->GetDefaultStream(); + int worker_id = ctx->worker->worker_id; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + int sender_id = worker_id - group_size; + CHECK_GE(sender_id, 0) + << "The current group is already the first group and there is no such a previous group."; + NCCL_CALL(ncclGroupStart()); + NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + sender_id, ctx->global_comm, stream)); + NCCL_CALL(ncclGroupEnd()); +} + +void SendToWorker(NDArray buffer, int receiver_id) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + deviceStream_t stream = ctx->GetDefaultStream(); + int worker_id = ctx->worker->worker_id; + CHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers) + << "Invalid receiver id " << receiver_id << ". The world size is " + << ctx->worker->num_workers; + CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself."; + NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + receiver_id, ctx->global_comm, stream)); +} + +void RecvFromWorker(NDArray buffer, int sender_id) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + deviceStream_t stream = ctx->GetDefaultStream(); + int worker_id = ctx->worker->worker_id; + CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers) + << "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers; + CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself."; + NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + sender_id, ctx->global_comm, stream)); +} + void SyncWorker() { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); ICHECK(ctx->worker != nullptr); @@ -284,8 +335,43 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0") .set_body_typed(GatherToWorker0); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0") .set_body_typed(RecvFromWorker0); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group") + .set_body_typed(SendToNextGroup); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group") + .set_body_typed(RecvFromPrevGroup); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker") + .set_body_typed(SendToWorker); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker") + .set_body_typed(RecvFromWorker); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME + ".test_send_to_next_group_recv_from_prev_group") + .set_body_typed([](NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; + CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + int group_id = ctx->worker->worker_id / group_size; + if (group_id == 0) { + tvm::runtime::nccl::SendToNextGroup(buffer); + } else { + tvm::runtime::nccl::RecvFromPrevGroup(buffer); + } + }); + +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0") + .set_body_typed([](NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; + CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + if (ctx->worker->worker_id == 2) { + tvm::runtime::nccl::SendToWorker(buffer, 0); + } else if (ctx->worker->worker_id == 0) { + tvm::runtime::nccl::RecvFromWorker(buffer, 2); + } + }); + } // namespace nccl } // namespace runtime } // namespace tvm diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py index 6c63f64554a3..c29ece957245 100644 --- a/tests/python/disco/test_ccl.py +++ b/tests/python/disco/test_ccl.py @@ -25,11 +25,11 @@ import tvm import tvm.testing from tvm import dlight as dl +from tvm import get_global_func from tvm import relax as rx from tvm.runtime import disco as di from tvm.runtime.relax_vm import VirtualMachine from tvm.script import relax as R -from tvm import get_global_func _all_session_kinds = [di.ThreadedSession, di.ProcessSession] _ccl = [get_global_func("runtime.disco.compiled_ccl")()] @@ -391,6 +391,44 @@ def test_group_gather(session_kind, ccl, capfd): ), "No warning messages should be generated from disco.Session.gather_to_worker0" +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_send_to_next_group_receive_from_prev_group(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(12, dtype="float32").reshape(3, 4) + array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4) + d_array = sess.empty((3, 4), "float32") + d_array.debug_copy_from(0, array_1) + d_array.debug_copy_from(1, array_2) + sess.get_global_func("runtime.disco." + ccl + ".test_send_to_next_group_recv_from_prev_group")( + d_array + ) + + result_1 = d_array.debug_get_from_remote(2).numpy() + result_2 = d_array.debug_get_from_remote(3).numpy() + np.testing.assert_equal(result_1, array_1) + np.testing.assert_equal(result_2, array_2) + + +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_worker2_send_to_worker0(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4) + d_array = sess.empty((3, 4), "float32") + d_array.debug_copy_from(2, array) + sess.get_global_func("runtime.disco." + ccl + ".test_worker2_sends_to_worker0")(d_array) + + result = d_array.debug_get_from_remote(0).numpy() + np.testing.assert_equal(result, array) + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals From 9f0f301c6f6de7548c6b2026bcb51590e0881ac5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 24 Jul 2024 08:24:15 -0500 Subject: [PATCH 025/202] [TIR][Analyzer] Simplify `x==x` expressions for all dtypes (#17158) * [TIR][Analyzer] Simplify `x==x` expressions for all dtypes Prior to this commit, there was no rule to simplify `x == x` into `True`. In some cases, despite not having an explicit rewrite rule in `RewriteSimplifier`, the `RewriteSimplifier::CanProve` function would check if `x-x` simplifies to zero, relying on the rewrite rules used for `tir::Sub`. However, the rule to rewrite `x-x` into zero was only enabled for `int32`, `int64`, and floating-point types, so relying on this behavior was inconsistent. This commit updates the rewrite rules for both `tir::EQ` and `tir::Sub` to check for simplification of `x-x` or `x==x`, regardless of the datatype. This change preserves the fast-path for index data-types, in which `int32` and `int64` expressions may be simplified without checking for side effects. For all other dtypes, the cancellation only applies when evaluating `x` has no side effects. * Add comment about simplifications of NaN/Inf --- src/arith/rewrite_simplify.cc | 21 ++++++++++- .../arith/test_arith_rewrite_simplify.py | 36 +++++++++++++++++++ tests/python/arith/test_arith_simplify.py | 29 +++++++++++++++ 3 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index f4d4a9048ced..3682054e8e4b 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -543,6 +543,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; + // Vector rules if (op->dtype.is_scalable_or_fixed_length_vector()) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); @@ -697,9 +698,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1)); TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y); TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1)); - } else if (op->dtype.is_float()) { + } else { // Cancellation rules. Deliberately off of the integer path, to // avoid introducing checks on the side effects for the fast path. + // + // These simplifications do not preserve NaN/Inf that may occur in + // the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does + // not cancel out. However, since models should not encounter NaN + // in the first place, this allows better simplification for the + // supported path. TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x), SideEffect(x.Eval()) <= CallEffectKind::kReadState); TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState); @@ -1678,6 +1685,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var match IntImm PVar c1, c2; PVar lanes; + PConst ctrue(make_const(ret->dtype, true)); // vector rule if (ret->dtype.is_scalable_or_fixed_length_vector()) { @@ -1698,6 +1706,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { TVM_TRY_REWRITE(c1 - x == c2, x == c1 - c2); TVM_TRY_REWRITE(x + c1 == c2, x == c2 - c1); TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0); + TVM_TRY_REWRITE(x == x, ctrue); + } else { + // Mimic the cancellation rules for SubNode. For Index datatypes, + // we skip the check for side effects. + // + // These simplifications do not preserve NaN/Inf that may occur in + // the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does + // not cancel out. However, since models should not encounter NaN + // in the first place, this allows better simplification for the + // supported path. + TVM_TRY_REWRITE_IF(x == x, ctrue, SideEffect(x.Eval()) <= CallEffectKind::kReadState); } return std::move(ret); } diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 1ebaab53af2d..90f0aeef47d7 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -321,6 +321,42 @@ class TestSelect(BaseCompare): ) +class TestCancellation(BaseCompare): + var_int8 = tir.Var("var_int8", "int8") + var_int32 = tir.Var("var_int32", "int32") + var_int64 = tir.Var("var_int64", "int64") + var_uint8 = tir.Var("var_uint8", "uint8") + var_uint32 = tir.Var("var_uint32", "uint32") + var_uint64 = tir.Var("var_uint64", "uint64") + + test_case = tvm.testing.parameter( + TestCase(tir.const(5, "int64") - tir.const(5, "int64"), tir.const(0, "int64")), + TestCase(tir.const(5, "uint8") - tir.const(5, "uint8"), tir.const(0, "uint8")), + TestCase(var_int8 - var_int8, tir.const(0, "int8")), + TestCase(var_int32 - var_int32, tir.const(0, "int32")), + TestCase(var_int64 - var_int64, tir.const(0, "int64")), + TestCase(var_uint8 - var_uint8, tir.const(0, "uint8")), + TestCase(var_uint32 - var_uint32, tir.const(0, "uint32")), + TestCase(var_uint64 - var_uint64, tir.const(0, "uint64")), + TestCase(tir.EQ(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(True, "bool")), + TestCase(tir.EQ(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(True, "bool")), + TestCase(tir.EQ(var_int8, var_int8), tir.const(True, "bool")), + TestCase(tir.EQ(var_int32, var_int32), tir.const(True, "bool")), + TestCase(tir.EQ(var_int64, var_int64), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint8, var_uint8), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint32, var_uint32), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint64, var_uint64), tir.const(True, "bool")), + TestCase(tir.NE(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(False, "bool")), + TestCase(tir.NE(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(False, "bool")), + TestCase(tir.NE(var_int8, var_int8), tir.const(False, "bool")), + TestCase(tir.NE(var_int32, var_int32), tir.const(False, "bool")), + TestCase(tir.NE(var_int64, var_int64), tir.const(False, "bool")), + TestCase(tir.NE(var_uint8, var_uint8), tir.const(False, "bool")), + TestCase(tir.NE(var_uint32, var_uint32), tir.const(False, "bool")), + TestCase(tir.NE(var_uint64, var_uint64), tir.const(False, "bool")), + ) + + class TestAddIndex(BaseCompare): x, y, z = te.var("x"), te.var("y"), te.var("z") diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 9a0245d27487..3b0237740045 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -38,6 +38,35 @@ def test_simplify_reshape_flattened_index(): ) +dtype = tvm.testing.parameter( + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", +) + + +def test_can_prove_self_identity(dtype): + ana = tvm.arith.Analyzer() + + n = tir.Var("n", dtype) + assert ana.can_prove(n == n) + + +def test_can_prove_self_equal_to_self(dtype): + ana = tvm.arith.Analyzer() + + n = tir.Var("n", dtype) + assert ana.can_prove_equal(n, n) + + def test_simplify_symbolic_comparison(): ana = tvm.arith.Analyzer() From cc8afdb0e3be52a3aa162ff14a81b11a793dca6b Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 24 Jul 2024 22:36:19 +0900 Subject: [PATCH 026/202] Add support for `torch.nn.functional.max_pool2d` (#17189) * add a testcase for call_function * add maxpool2d to call_function --- python/tvm/relax/frontend/torch/fx_translator.py | 1 + tests/python/relax/test_frontend_from_fx.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e6b39c3eee0e..093f3ae4cf7a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1476,6 +1476,7 @@ def create_convert_map(self): "getitem": self._getitem, "contiguous": lambda node: self.env[node.args[0]], "to": self._to, + "max_pool2d": self._max_pool2d, "avg_pool2d": self._avg_pool2d, "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), "layer_norm": self._layer_norm, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index b4ac3fa60ce9..1a2cc5da6242 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -796,6 +796,13 @@ def __init__(self): def forward(self, input): return self.pool(input) + class MaxPool2d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1]) + @tvm.script.ir_module class expected1: @R.function @@ -876,6 +883,7 @@ def main( return gv verify_model(MaxPool2d(), input_info, {}, expected1) + verify_model(MaxPool2d_functional(), input_info, {}, expected1) verify_model(MaxPool2d2(), input_info, {}, expected2) verify_model(MaxPool2d3(), input_info, {}, expected3) From 7bd738a00b08ee5cd89623075f2f692c246881fd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 24 Jul 2024 10:42:02 -0500 Subject: [PATCH 027/202] [Relax] Implement Rewriter class for pattern-rewrite (#17149) * [TVMScript][Bugfix] Normalize relax::If with function's TIR var Prior to this commit, the branches of `relax::If` were normalized using `EraseToWellDefinedInScope`, using a fresh variable scope. While this had the intended behavior of preventing variables defined in a single branch from being usable outside of the conditional, it also caused the conditional's branches to treat function-scope symbolic variables as if they were undefined. This commit updates the `tvm::relax::Normalizer` so that `relax::If` is normalized within an inherited scope. This preserves the previous behavior for symbolic variables defined within a branch, but allows shapes within a branch to use symbolic variables defined outside of the branch. * [Relax] Canonicalize known symbolic shapes in Relax expressions Prior to this commit, known constants in Relax functions would be inlined by the `CanonicalizeBindings` pass, but only if they appeared as Relax expressions (e.g. `R.const` or `R.prim_value`). Known constants that appeared as TIR variables (e.g. symbolic shapes) would be kept as dynamic parameters, even if they were known at compile time. This commit updates the `CanonicalizeBindings` pass to identify known values of symbolic shapes, and to use these known values in shape expressions. * [Relax][Refactor] Reorganize pattern-matching A follow-up to https://github.com/apache/tvm/pull/16730. Now that the implementations for `rewrite_call` and `rewrite_bindings` are in separate classes, they can be further split out into separate files. * [Relax][Refactor] Implement Rewriter class for pattern-rewrite Prior to this commit, the pattern to be matched and the rewrite to be performed were provided as separate arguments. This commit introduces a new class `ExprRewriter`, which contains both parts. This abstraction will make it easier to combine multiple different rewrite rules, applying them in a single pass. * lint fixes * Remove unnecessary change which broke a unit test * lint fix for import order * Add docstrings * lint fix * Lint fix * lint fixes * lint fix * Update based on review comments * Add test case for matching against arbitrary dtype * Fix breakage in unit tests One unit test that had been relying on invalid shape propagation. Another unit test that required constructed an ill-formed output to test against. * Updated base class name from ExprRewriter to PatternMatchingRewriter * lint fix --- include/tvm/relax/block_builder.h | 35 +- include/tvm/relax/expr_functor.h | 21 +- include/tvm/script/ir_builder/relax/frame.h | 1 + python/tvm/relax/dpl/__init__.py | 8 +- python/tvm/relax/dpl/rewrite.py | 186 +- python/tvm/script/ir_builder/relax/ir.py | 48 +- python/tvm/script/parser/core/utils.py | 14 +- src/relax/ir/block_builder.cc | 95 +- src/relax/ir/dataflow_block_rewriter.cc | 452 +++++ src/relax/ir/dataflow_expr_rewriter.cc | 1079 ++++++++++++ src/relax/ir/dataflow_matcher.cc | 669 +------- ...flow_matcher_impl.h => dataflow_matcher.h} | 15 +- src/relax/ir/dataflow_rewriter.h | 182 ++ src/relax/ir/expr.cc | 42 +- src/relax/ir/expr_functor.cc | 54 +- src/relax/transform/canonicalize_bindings.cc | 142 +- src/relax/transform/utils.h | 2 +- src/relax/utils.cc | 16 +- src/script/ir_builder/relax/frame.cc | 7 +- src/script/ir_builder/relax/ir.cc | 10 +- tests/python/relax/test_dataflow_rewriter.py | 1512 +++++++++++++++++ .../test_transform_canonicalize_bindings.py | 255 ++- .../test_transform_legalize_ops_manipulate.py | 2 +- tests/python/relax/test_tvmscript_parser.py | 46 + 24 files changed, 4142 insertions(+), 751 deletions(-) create mode 100644 src/relax/ir/dataflow_block_rewriter.cc create mode 100644 src/relax/ir/dataflow_expr_rewriter.cc rename src/relax/ir/{dataflow_matcher_impl.h => dataflow_matcher.h} (91%) create mode 100644 src/relax/ir/dataflow_rewriter.h create mode 100644 tests/python/relax/test_dataflow_rewriter.py diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 7ca9aab6d5aa..ad2b9820707a 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -133,16 +133,47 @@ class BlockBuilderNode : public Object { * \brief Begin a new scope, with optional parameters that * are visible within the scope. * + * Symbolic variables from the parent scope are not available. + * * \param params Parameters that are visible within the scope. * * \note This function should be called when new scope is introduced - * (function, seq) to properly track the variable availability - * and help the best effort deduction. + * (e.g. function bodies) to properly track the variable + * availability and help the best effort deduction. * * \sa EndScope */ virtual void BeginScope(Optional> params) = 0; + /*! + * \brief Begin a new scope, which inherits visible parameters from + * its parent scope. + * + * Symbolic variables from the parent scope are available. + * + * \note This function should be called when an inner scope is + * introduced (e.g. conditional branches) to properly track + * the variable availability and help the best effort + * deduction. + * + * \sa EndScope + */ + virtual void BeginInnerScope() = 0; + + /*! + * \brief Append a definition to the current scope. + * + * \param var A variable within the current scope. + * + * \note This function should be called when a new variable is + * defined that may impact struct inference (e.g. MatchCast) + * to properly track the variable availability and help the + * best effort deduction. + * + * \sa EndScope + */ + virtual void AddDefinitionToScope(Var var) = 0; + /*! \brief End the previously defined scope. */ virtual void EndScope() = 0; diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index ce209ccd460f..c3aea24dcb50 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -494,7 +494,10 @@ class ExprMutator : public ExprMutatorBase { void ReEmitBinding(const VarBindingNode* binding, Expr new_value); /*! - * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If. + * \brief Rewrite the expr with a new scope, used in a Function's body. + * + * Visit an expression that may neither access variables from the + * current scope, nor may export definitions into the current scope. * * \param body_expr The body to be visited. * \param params Optional parameters that are visible within the scope. @@ -504,6 +507,22 @@ class ExprMutator : public ExprMutatorBase { */ Expr VisitWithNewScope(const Expr& body_expr, Optional> params = NullOpt); + /*! + * \brief Rewrite the expr with a new scope, used in the branches of If. + * + * Visit an expression that may access variables from the current + * scope, but may not export definitions into the current scope. + * + * \param body_expr The body to be visited. + * + * \return The expr after visiting. + * + * \sa VisitWithNewScope + * + * \note The body_expr must be an SeqExpr in the normal form. + */ + Expr VisitWithInnerScope(const Expr& body_expr); + /*! * \brief Look up the value bound to a variable. * \param var The var to be looked up. diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 1ad681388912..0ee144f03e77 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -122,6 +122,7 @@ class FunctionFrameNode : public SeqExprFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); public: + void EnterWithScope() final; void ExitWithScope() final; }; diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py index 6451238428c2..a4f3f4063e90 100644 --- a/python/tvm/relax/dpl/__init__.py +++ b/python/tvm/relax/dpl/__init__.py @@ -19,4 +19,10 @@ from .pattern import * from .context import * -from .rewrite import rewrite_call, rewrite_bindings +from .rewrite import ( + rewrite_call, + rewrite_bindings, + PatternMatchingRewriter, + ExprPatternRewriter, + OrRewriter, +) diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 291061090fc2..96c69e9266a2 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -15,16 +15,196 @@ # specific language governing permissions and limitations # under the License. """APIs for pattern-based rewriting.""" -from typing import Dict, Callable + +from typing import Dict, Callable, Union + +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm._ffi import register_object + from .pattern import DFPattern from .context import PatternContext - from ..expr import Expr, Function, Var from . import _ffi as ffi +@register_object("relax.dpl.PatternMatchingRewriter") +class PatternMatchingRewriter(Object): + """A pattern-matching rewriter for Relax""" + + @staticmethod + def from_pattern( + pattern: DFPattern, + func: Callable[[Expr, Dict[DFPattern, Expr]], Expr], + ) -> "PatternMatchingRewriter": + """Construct from a pattern and rewriter-function + + The replacements performed by the rewriter will be equivalent + to using the `pattern` and `func` as arguments to + `rewrite_call`. + + Parameters + ---------- + pattern: DFPattern + + The pattern to be matched against. + + func: Callable[[Expr, Dict[DFPattern, Expr]], Expr] + + A function that returns the rewritten expression. See + `rewrite_call` for details and examples. + + + Returns + ------- + rewriter_obj: PatternMatchingRewriter + + The rewriter object + + """ + return ffi.PatternMatchingRewriterFromPattern( + pattern, + func, + ) # type: ignore + + @staticmethod + def from_module(mod: IRModule) -> "PatternMatchingRewriter": + """Construct a rewriter from an IRModule + + The IRModule must have two publicly-exposed functions, + `pattern` and `replacement`, where `pattern` and `replacement` + have the same function signature, as shown in the example + below. + + .. code-block:: python + + @I.ir_module + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + rewriter = PatternMatchingRewriter.from_module(RewriteAddIntoMultiply) + rewritten_ir_module = rewriter(ir_module) + + To support the common case of defining an IRModule with + TVMScript, then immediately turning it into a rewriter, the + `@R.rewriter` annotation can be used. + + .. code-block:: python + + @R.rewriter + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + rewritten_ir_module = RewriteAddIntoMultiply(ir_module) + + Parameters + ---------- + mod: IRModule + + A module with `pattern` and `replacement` functions, + defining a rewrite rule. + + + Returns + ------- + rewriter_obj: PatternMatchingRewriter + + The rewriter object + + """ + return ffi.PatternMatchingRewriterFromModule(mod) # type: ignore + + def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]: + """Apply the rewriter + + Parameters + ---------- + obj: Union[Expr, IRModule]) + + The object to be rewritten. May be applied to either a + relax expression, or an IRModule. + + Returns + ------- + updated: Union[Expr, IRModule] + + The rewritten object + + """ + return ffi.PatternMatchingRewriterApply(self, obj) + + def __or__(self, other: "PatternMatchingRewriter") -> "PatternMatchingRewriter": + """Compose two rewriters + + Composing two rewrite rules together allows them to be applied + in a single Relax-level transformation. + + Parameters + ---------- + other: PatternMatchingRewriter + + Another rewrite rule + + Returns + ------- + PatternMatchingRewriter + + A rewriter that will apply either rewrite pattern + + """ + return OrRewriter(self, other) + + +@register_object("relax.dpl.ExprPatternRewriter") +class ExprPatternRewriter(PatternMatchingRewriter): + def __init__(self, pattern, func): + self.__init_handle_by_constructor__( + ffi.PatternRewriter, + pattern, + func, + ) # type: ignore + + +@register_object("relax.dpl.OrRewriter") +class OrRewriter(PatternMatchingRewriter): + def __init__(self, lhs, rhs): + self.__init_handle_by_constructor__( + ffi.OrRewriter, + lhs, + rhs, + ) # type: ignore + + +@register_object("relax.dpl.TupleRewriter") +class TupleRewriter(PatternMatchingRewriter): + def __init__(self, patterns, func): + self.__init_handle_by_constructor__( + ffi.TupleRewriter, + patterns, + func, + ) # type: ignore + + def rewrite_call( - pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function + pattern: DFPattern, + rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], + func: Function, ) -> Function: """ Rewrite a function with the given pattern and the rewriter function. diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index ef9ae775450b..c4be8afac4d2 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -20,11 +20,11 @@ import builtins import functools import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type import tvm from tvm import DataType, relax -from tvm.ir import PrimExpr, VDevice +from tvm.ir import PrimExpr, VDevice, IRModule from tvm.relax import ( Call, Expr, @@ -35,6 +35,7 @@ VarBinding, const, ) +from tvm.relax.dpl import PatternMatchingRewriter ############################### Operators ############################### from tvm.relax.op import ( @@ -306,6 +307,48 @@ def func_ret_value(value: Expr) -> None: return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member +def rewriter(rewriter_mod: Union[IRModule, Type]) -> PatternMatchingRewriter: + """Define a pattern-rewrite rule + + The IRModule must have two publicly-exposed functions, `pattern` + and `replacement`, where `pattern` and `replacement` have the same + function signature. + + .. code-block:: python + + @R.rewriter + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + Parameters + ---------- + rewriter_mod: Union[IRModule, Type] + + Either an IRModule that defines a rewrite pattern, or a + TVMScript class that can be parsed into an IRModule. + + Returns + ------- + rewriter: PatternMatchingRewriter + + A rewriter object, which can be applied either to a Relax + function or to an entire IRModule. + + """ + if not isinstance(rewriter_mod, IRModule): + rewriter_mod = tvm.script.ir_module(rewriter_mod) + + return PatternMatchingRewriter.from_module(rewriter_mod) + + ############################# BindingBlock ############################## @@ -765,6 +808,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "dequantize", "repeat", "reshape", + "rewriter", "tensor_to_shape", "shape_to_tensor", "rocm", diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py index 3edae3f25a33..8ad64f5dbc68 100644 --- a/python/tvm/script/parser/core/utils.py +++ b/python/tvm/script/parser/core/utils.py @@ -100,19 +100,29 @@ def is_defined_in_class(frames: List[FrameType], obj: Any) -> bool: res : bool The result if the object is defined in a class scope. """ + + def _is_tvmscript_class_annotator(line: str) -> bool: + """Checks if the line contains a TVMScript annotator for a class + + These match either `@I.ir_module` or `@R.rewriter`, or their + imported names `@ir_module` or `@rewriter`. + """ + + return line.startswith("@") and ("ir_module" in line or "rewriter" in line) + if len(frames) > 2: frame_info = frames[2] code_context = frame_info.code_context if code_context is None: return False line = code_context[0].strip() - if line.startswith("@") and "ir_module" in line: + if _is_tvmscript_class_annotator(line): return True if line.startswith("class"): lineno = frame_info.lineno if lineno >= 2: source, _ = findsource(obj) line = source[lineno - 2].strip() - if line.startswith("@") and "ir_module" in line: + if _is_tvmscript_class_annotator(line): return True return False diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index f6aec79a4ac4..b8092bbf3a4d 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -178,29 +178,54 @@ class BlockBuilderImpl : public BlockBuilderNode { // but can be further improved. // // TODO(relax-team): Add support for relax Var in struct info annotations. - Map shape_var_map; - for (const Var& var : params.value_or(Array())) { - const Map& var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); - for (const auto& kv : var_map) { - const tir::Var& shape_var = kv.first; - const PrimExpr& shape_expr = kv.second; - auto it = shape_var_map.find(shape_var); - if (it == shape_var_map.end()) { - shape_var_map.Set(shape_var, shape_expr); - // Expose the shape variable as non-negative, for purposes - // of shape inference. In many cases, knowning that the - // shape variable is non-negative allows for simpler - // expressions for dynamic shapes. - analyzer_.MarkGlobalNonNegValue(shape_var); - } else { - const PrimExpr& old_shape_expr = (*it).second; - CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) - << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " - << shape_expr; - } + + scope_stack_.emplace_back(ScopeFrame()); + if (params.defined()) { + for (const auto& param : params.value()) { + AddDefinitionToScope(param); + } + } + } + + void BeginInnerScope() final { + if (scope_stack_.size()) { + scope_stack_.emplace_back(scope_stack_.back()); + } else { + scope_stack_.emplace_back(ScopeFrame()); + } + } + + void AddDefinitionToScope(Var var) final { + if (scope_stack_.empty()) { + return; + } + + auto& shape_var_map = CurrentScopeFrame()->shape_var_map; + + // The current implementation handles the collection of shape var + // defined in parameter struct info annotations. The implementation + // is correct (since we will simply erase all relax Vars in EraseToWellDefined), + // but can be further improved. + Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + for (const auto& kv : var_map) { + const tir::Var& shape_var = kv.first; + const PrimExpr& shape_expr = kv.second; + auto it = shape_var_map.find(shape_var); + if (it == shape_var_map.end()) { + shape_var_map.Set(shape_var, shape_expr); + // Expose the shape variable as non-negative, for purposes + // of shape inference. In many cases, knowning that the + // shape variable is non-negative allows for simpler + // expressions for dynamic shapes. + analyzer_.MarkGlobalNonNegValue(shape_var); + } else { + const PrimExpr& old_shape_expr = (*it).second; + CHECK(old_shape_expr.same_as(shape_expr) || + analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " + << shape_expr; } } - scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)})); } void EndScope() final { scope_stack_.pop_back(); } @@ -236,6 +261,8 @@ class BlockBuilderImpl : public BlockBuilderNode { cur_frame->bindings.push_back(match_cast); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. + + AddDefinitionToScope(var); return var; } @@ -271,6 +298,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. cur_frame->bindings.push_back(binding); + AddDefinitionToScope(match_cast->var); } else { LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); } @@ -831,7 +859,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor Optional { @@ -843,15 +873,18 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor> params = NullOpt) { + if (params.defined()) { + this->BeginScope(params.value()); + } else { + this->BeginInnerScope(); + } + + Expr ret; + // SeqExpr do not need to prepare for normalization. if (expr.as()) { - this->BeginScope(params); - Expr ret = this->VisitExpr(expr); - this->EndScope(); - return ret; + ret = this->VisitExpr(expr); } else { - this->BeginScope(params); - this->BeginBindingBlock(); Expr post = this->NormalizeArgument(expr); BindingBlock prologue = this->EndBlock(); @@ -868,9 +901,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorbody))); - this->EndScope(); - return seq; + ret = seq; } + + this->EndScope(); + return ret; } Array FlattenBlocks(const Array& blocks) { diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc new file mode 100644 index 000000000000..fb08dfe96a17 --- /dev/null +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_block_rewriter.cc + * \brief A transform to match a Relax DataflowBlock and rewrite + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "dataflow_matcher.h" +#include "dataflow_rewriter.h" + +namespace tvm { +namespace relax { + +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::vector vars; + std::map> def2use; + // caller -> callee table. + std::map> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + auto check_and_push = [](std::vector& vec, const VarNode* var) { + if (std::find(vec.begin(), vec.end(), var) == vec.end()) { + vec.push_back(var); + } + }; + + check_and_push(def2use[op], cur_user_); + check_and_push(vars, op); + + caller2callees[cur_user_].push_back(op); + } +}; + +struct PNode { + const DFPatternNode* ptr; + std::vector&>> children; + std::vector&>> parents; +}; + +struct RNode { + const VarNode* ptr; + std::vector children; + std::vector parents; +}; + +struct MatchState { + void add(const PNode* p, const RNode* r) { + match_p_r[p] = r; + match_r_p[r] = p; + } + + void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); } + + void add(MatchState&& other) { + match_p_r.merge(std::move(other.match_p_r)); + match_r_p.merge(std::move(other.match_r_p)); + validated_constraints_.merge(other.validated_constraints_); + } + + const VarNode* matched(const PNode* p) const { + if (auto it = match_p_r.find(p); it != match_p_r.end()) { + return it->second->ptr; + } + return nullptr; + } + + const DFPatternNode* matched(const RNode* r) const { + if (auto it = match_r_p.find(r); it != match_r_p.end()) { + return it->second->ptr; + } + return nullptr; + } + + const VarNode* matched(const PNode& p) const { return matched(&p); } + const DFPatternNode* matched(const RNode& r) const { return matched(&r); } + + bool is_validated(const DFConstraintNode* constraint) const { + return validated_constraints_.count(constraint); + } + + private: + std::unordered_map match_p_r; + std::unordered_map match_r_p; + std::unordered_set validated_constraints_; +}; + +/** + * \brief This method try to match a real node and a pattern node along with its neighbors. + */ +static std::optional TryMatch(const PNode& p, const RNode& r, + const MatchState& current_match, DFPatternMatcher* m, + const MatcherUseDefAnalysis& ud_analysis) { + if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; + + MatchState new_match; + + new_match.add(&p, &r); + + // forward matching; + for (const auto& [pchild, constraints] : p.children) { + bool any_cons_sat = false; + for (const auto& rchild : r.children) { + if (new_match.matched(rchild)) { + // The child variable is already matched to other child pattern in a previous iteration. + continue; + } + if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { + // The child pattern is already matched to other variable in a earlier call to TryMatch. + continue; + } + + const auto& uses = ud_analysis.def2use.at(r.ptr); + + // check edge constraints. + bool all_cons_pass = true; + for (const auto& cons : constraints) { + if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { + all_cons_pass = false; + break; + } + + if (cons.index != -1) { + const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); + if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { + all_cons_pass = false; + break; + } + } + } + if (!all_cons_pass || new_match.matched(pchild)) continue; + any_cons_sat = true; + + if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { + new_match.add(pchild, rchild); + new_match.add(std::move(*match_rec)); + } + } + if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; + } + + return new_match; +} + +static std::optional TryValidate( + const MatchState& current_match, + const std::unordered_map& pattern2node, + const std::vector& validation_constraints, arith::Analyzer* analyzer) { + MatchState new_match; + + std::function(const DFPatternNode*)> query_match_state = + [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> Optional { + auto it = pattern2node.find(pattern); + ICHECK(it != pattern2node.end()) + << "DFConstraint attempted to access DFPattern " << GetRef(pattern) + << ", which does not appear in the PatternContext"; + const auto& p_node = it->second; + if (auto ptr = current_match.matched(p_node)) { + return GetRef(ptr); + } else { + return NullOpt; + } + }; + + for (const auto& constraint : validation_constraints) { + if (!current_match.is_validated(constraint.get())) { + auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); + + necessary_condition = analyzer->Simplify(necessary_condition); + const auto* known = tir::as_const_int(necessary_condition); + + if (known && *known && is_sufficient) { + // The condition passes, and the expression provided is both + // necessary and sufficient for the constraint to pass. Mark + // the constraint as passing, to avoid re-checking it unless + // we backtrack. + new_match.add(constraint.get()); + } else if (known && !*known) { + // The condition fails. Even if additional information would + // be required to pass a constraint, it may bail out early as + // a failure (e.g. shape mismatch in the first two items out + // of N shapes that must all match). + return std::nullopt; + } else if (is_sufficient) { + // The condition depends on dynamic parameters. In the + // future, this may be exposed to the user as a condition for + // optimization, or can be combined with the conditions + // provided from other constraints. + return std::nullopt; + } + } + } + + return new_match; +} + +static std::optional MatchTree( + const MatchState& current_match, size_t current_root_idx, + const std::unordered_map& pattern2node, + const std::unordered_map& var2node, DFPatternMatcher* matcher, + const std::vector& roots, const std::vector& validation_constraints, + const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { + auto get_next_root = [&](size_t root_idx) -> const PNode* { + // Look for the next unmatched root node. + for (; root_idx < roots.size(); ++root_idx) { + const auto& root = pattern2node.at(roots[root_idx].get()); + if (!current_match.matched(root)) { + return &root; + } + } + return nullptr; + }; + + const auto root = get_next_root(current_root_idx); + + if (!root) { + // All root nodes have been matched + return current_match; + } + + MatchState new_match = current_match; + + for (const auto& var : ud_analysis.vars) { + const RNode& r_node = var2node.at(var); + if (new_match.matched(r_node)) continue; + if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { + // Recursively try to match the next subtree. + new_match.add(std::move(*match)); + if (auto validation = + TryValidate(new_match, pattern2node, validation_constraints, analyzer)) { + new_match.add(std::move(*validation)); + if (auto match_rec = + MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, matcher, roots, + validation_constraints, ud_analysis, analyzer)) { + new_match.add(std::move(*match_rec)); + return new_match; + } + } + // Recursive matching has failed, backtrack. + new_match = current_match; + continue; + } + } + + return std::nullopt; +} + +Optional> MatchGraph(const PatternContext& ctx, + const Array& binding_arr, + const Map& bindings) { + // TODO(@ganler): Handle non-may external use. + ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; + DFPatternMatcher matcher(bindings); + + MatcherUseDefAnalysis ud_analysis; + for (const auto& binding : binding_arr) { + ud_analysis.VisitBinding(binding); + } + + // First construct a graph of PNode and RNode. + std::unordered_map var2node; + var2node.reserve(bindings.size()); + + for (const VarNode* cur_var : ud_analysis.vars) { + const auto& uses = ud_analysis.def2use.at(cur_var); + RNode& cur_node = var2node[cur_var]; + cur_node.ptr = cur_var; + for (const VarNode* use : uses) { + auto& use_node = var2node[use]; + use_node.ptr = use; + cur_node.children.push_back(&use_node); + use_node.parents.push_back(&cur_node); + } + } + + std::unordered_map pattern2node; + pattern2node.reserve(ctx->edge_constraints.size()); + + for (const auto& def_pattern : ctx->src_ordered) { + PNode& def_node = pattern2node[def_pattern.get()]; + const auto& uses = ctx->edge_constraints.at(def_pattern); + def_node.ptr = def_pattern.get(); + def_node.children.reserve(uses.size()); + for (const auto& [use_pattern, cons] : uses) { + PNode& use_node = pattern2node[use_pattern.get()]; + use_node.ptr = use_pattern.get(); + use_node.parents.emplace_back(&def_node, std::ref(cons)); + def_node.children.emplace_back(&use_node, std::ref(cons)); + } + } + + std::vector roots; + for (const auto& pat : ctx->src_ordered) { + if (pattern2node[pat.get()].parents.empty()) { + roots.push_back(pat); + } + } + + if (roots.empty()) { + return NullOpt; + } + + arith::Analyzer analyzer; + auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, + ctx->validation_constraints, ud_analysis, &analyzer); + if (!match) { + return NullOpt; + } + + Map ret; + for (const auto& [pat, p_node] : pattern2node) { + ICHECK(match->matched(p_node)); + ret.Set(GetRef(pat), GetRef(match->matched(p_node))); + } + return ret; +} + +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { + return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") + .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { + return MatchGraph(ctx, dfb); + }); + +class PatternContextRewriterNode : public PatternMatchingRewriterNode { + public: + PatternContext pattern; + TypedPackedFunc(Map, Map)> rewriter_func; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("pattern", &pattern); + PackedFunc untyped_func = rewriter_func; + visitor->Visit("rewriter_func", &untyped_func); + } + + static constexpr const char* _type_key = "relax.dpl.PatternContextRewriter"; + TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, PatternMatchingRewriterNode); + + private: + Optional> MatchBindings(const Array& bindings) const { + Map var_lookup; + for (const auto& binding : bindings) { + var_lookup.Set(binding->var, GetBoundValue(binding)); + } + + if (auto matches = MatchGraph(pattern, bindings, var_lookup)) { + Map replacements = rewriter_func(matches.value(), var_lookup); + if (replacements.size()) { + return replacements; + } + } + + return NullOpt; + } +}; + +class PatternContextRewriter : public PatternMatchingRewriter { + public: + PatternContextRewriter( + PatternContext pattern, + TypedPackedFunc(Map, Map)> rewriter_func); + + TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, PatternMatchingRewriter, + PatternContextRewriterNode); +}; + +RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bindings) const { + std::vector remaining_bindings{bindings.begin(), bindings.end()}; + + Map variable_rewrites; + while (auto opt = MatchBindings(remaining_bindings)) { + auto new_rewrites = opt.value(); + remaining_bindings.erase(std::remove_if(remaining_bindings.begin(), remaining_bindings.end(), + [&new_rewrites](const Binding& binding) { + return new_rewrites.count(binding->var); + }), + remaining_bindings.end()); + for (const auto& [var, expr] : new_rewrites) { + variable_rewrites.Set(var, expr); + } + } + + return RewriteSpec{variable_rewrites, {}}; +} + +PatternContextRewriter::PatternContextRewriter( + PatternContext pattern, + TypedPackedFunc(Map, Map)> rewriter_func) { + auto node = make_object(); + node->pattern = std::move(pattern); + node->rewriter_func = std::move(rewriter_func); + data_ = std::move(node); +} + +Function RewriteBindings( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter, Function func) { + // return BlockPatternRewriter::Run(ctx, rewriter, func); + return Downcast(PatternContextRewriter(ctx, rewriter)(func)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc new file mode 100644 index 000000000000..514116c5cadf --- /dev/null +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -0,0 +1,1079 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_expr_rewriter.cc + * \brief A transform to match a Relax Expr and rewrite + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../transform/utils.h" +#include "dataflow_matcher.h" +#include "dataflow_rewriter.h" + +namespace tvm { +namespace relax { + +namespace { +class GlobalVarReplacer : public ExprMutator { + public: + explicit GlobalVarReplacer(Map gvar_map) : gvar_map_(gvar_map) {} + + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const GlobalVarNode* op) override { + auto gvar = GetRef(op); + if (auto opt = gvar_map_.Get(gvar)) { + gvar = opt.value(); + } + return gvar; + } + + private: + Map gvar_map_; +}; + +Array TopologicalSort(const Array& bindings) { + std::unordered_set remaining_bindings; + for (const auto& binding : bindings) { + remaining_bindings.insert(binding->var); + } + + // Utility structure used to track bindings that are moved later in + // the list. + struct DelayedBinding { + Binding binding; + std::unordered_set unmet_requirements; + bool emitted; + }; + std::vector delayed_bindings; + Array sorted_bindings; + + // Utility function to append the + auto push_sorted_binding = [&](Binding binding) { + sorted_bindings.push_back(binding); + remaining_bindings.erase(binding->var); + for (auto& delayed_binding : delayed_bindings) { + delayed_binding.unmet_requirements.erase(binding->var); + } + }; + + bool required_sorting = false; + for (const auto& binding : bindings) { + // Collect any variables used by this binding, but are emitted by + // a later binding. + std::unordered_set unmet_requirements; + for (auto free_var : FreeVars(GetBoundValue(binding))) { + if (remaining_bindings.count(free_var)) { + unmet_requirements.insert(free_var); + } + } + + if (unmet_requirements.empty()) { + push_sorted_binding(binding); + } else { + required_sorting = true; + delayed_bindings.push_back(DelayedBinding{binding, unmet_requirements, false}); + } + + bool requires_delayed_binding_check = true; + while (requires_delayed_binding_check) { + requires_delayed_binding_check = false; + for (auto& delayed_binding : delayed_bindings) { + if (!delayed_binding.emitted && delayed_binding.unmet_requirements.empty()) { + // If we find a delayed binding that can be emitted, mark it + // as emitted and push to the sorted list. This may + delayed_binding.emitted = true; + requires_delayed_binding_check = true; + push_sorted_binding(delayed_binding.binding); + + // The break is not necessary for a topological sort, but is + // necessary to minimize the amount of re-ordering that is + // performed. With this break, the next binding is always + // the earliest binding that is legal to emit at this point. + break; + } + } + } + + // Remove any delayed bindings that have been emitted, now that we + // are done iterating over the delayed bindings. + delayed_bindings.erase( + std::remove_if(delayed_bindings.begin(), delayed_bindings.end(), + [](const auto& delayed_binding) { return delayed_binding.emitted; }), + delayed_bindings.end()); + } + + // All bindings should be emitted by this point. If any remain, + // then there exists a circular dependency somewhere in the + // remaining bindings. + CHECK(delayed_bindings.empty()) << "ValueError: " + << "Bindings contain circular dependency"; + + if (required_sorting) { + return sorted_bindings; + } else { + return bindings; + } +} +} // namespace + +void RewriteSpec::Append(RewriteSpec other) { + if (variable_rewrites.empty()) { + *this = std::move(other); + return; + } + if (other.variable_rewrites.empty()) { + return; + } + + NameSupply gvar_name_supply(""); + for (const auto& [gvar, func] : new_subroutines) { + gvar_name_supply->ReserveName(gvar->name_hint); + } + + Map gvar_rewrites; + for (auto [gvar, func] : other.new_subroutines) { + if (auto it = new_subroutines.find(gvar); it != new_subroutines.end()) { + // The two rewrites provide the same GlobalVar. + // (e.g. Multiple rewrites of the same pattern.) Ensure that + // they are referring to the same underlying BaseFunc. + CHECK(func.same_as((*it).second)); + } else if (auto new_name = gvar_name_supply->FreshName(gvar->name_hint); + new_name != gvar->name_hint) { + // The two rewrites provide distinct GlobalVar subroutines, + // but with conflicting names. Because an IRModule must have + // enough names for each GlobalVar, even if they are not + // publicly exposed, one of the GlobalVars must be replaced. + // Replacing the GlobalVar here, when the conflict is first + // identified, minimizes the size of the `relax::Expr` that + // must be updated with `GlobalVarReplacer`. + GlobalVar new_gvar = gvar; + new_gvar.CopyOnWrite()->name_hint = new_name; + gvar_rewrites.Set(gvar, new_gvar); + new_subroutines.Set(new_gvar, func); + } else { + new_subroutines.Set(gvar, func); + } + } + + for (auto [var, expr] : other.variable_rewrites) { + if (gvar_rewrites.size()) { + expr = GlobalVarReplacer(gvar_rewrites)(expr); + } + variable_rewrites.Set(var, expr); + } +} + +TVM_REGISTER_NODE_TYPE(PatternMatchingRewriterNode); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") + .set_body_typed([](DFPattern pattern, + TypedPackedFunc(Expr, Map)> func) { + return PatternMatchingRewriter::FromPattern(pattern, func); + }); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromModule").set_body_typed([](IRModule mod) { + return PatternMatchingRewriter::FromModule(mod); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterApply") + .set_body_typed([](PatternMatchingRewriter rewriter, + Variant obj) -> Variant { + if (auto expr = obj.as()) { + return rewriter(expr.value()); + } else if (auto mod = obj.as()) { + return rewriter(mod.value()); + } else { + LOG(FATAL) << "Unreachable: object does not contain either variant type"; + } + }); + +TVM_REGISTER_NODE_TYPE(ExprPatternRewriterNode); + +RewriteSpec ExprPatternRewriterNode::RewriteBindings(const Array& bindings) const { + Map variable_rewrites; + Map binding_lookup; + for (const auto& binding : bindings) { + auto bound_value = GetBoundValue(binding); + if (auto new_expr = RewriteExpr(bound_value, binding_lookup)) { + variable_rewrites.Set(binding->var, new_expr.value()); + } else { + binding_lookup.Set(binding->var, bound_value); + } + } + if (variable_rewrites.size()) { + return RewriteSpec{variable_rewrites, new_subroutines}; + } else { + return RewriteSpec(); + } +} + +Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr, + const Map& bindings) const { + if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings)) { + auto matches = opt_matches.value(); + if (additional_bindings) { + // Append any additional matches that from the unwrapped + // `OrPattern`. When matching against `pat = pat_lhs | + // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and + // `pat_rhs` separately. The top-level `pat` is never seen by + // `ExtractMatchedExpr`, and must be re-added afterward. + auto matched_expr = DFPatternMatcher::UnwrapBindings(expr, bindings); + for (const auto& pat : additional_bindings.value()) { + matches.Set(pat, matched_expr); + } + } + + Optional rewritten_expr = func(expr, matches); + if (rewritten_expr.defined() && !rewritten_expr.same_as(expr)) { + return rewritten_expr.value(); + } + } + return NullOpt; +} + +TVM_REGISTER_GLOBAL("relax.dpl.PatternRewriter") + .set_body_typed([](DFPattern pattern, + TypedPackedFunc(Expr, Map)> func) { + return ExprPatternRewriter(pattern, func); + }); + +ExprPatternRewriter::ExprPatternRewriter( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, Map new_subroutines) { + auto node = make_object(); + node->pattern = std::move(pattern); + node->func = std::move(func); + node->additional_bindings = std::move(additional_bindings); + node->new_subroutines = std::move(new_subroutines); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(OrRewriterNode); + +RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) const { + auto lhs_match = lhs->RewriteBindings(bindings); + if (!lhs_match) { + // If no rewrites found on LHS, RHS is allowed to modify any + // variable binding. + return rhs->RewriteBindings(bindings); + } + + // The LHS matched some subset of the bindings. These + // replacements may not be normalized expressions, so the RHS may + // only replace variable bindings that haven't been modified by + // the LHS. Variable replacements from the RHS may still occur, + // but will need to wait for the next round of + // iterate-until-converged. + Array remaining_bindings; + for (const auto& binding : bindings) { + if (!lhs_match.variable_rewrites.count(binding->var)) { + remaining_bindings.push_back(binding); + } + } + + if (remaining_bindings.empty()) { + // Early bail-out, the RHS has no bindings available to rewrite. + return lhs_match; + } + + lhs_match.Append(rhs->RewriteBindings(remaining_bindings)); + return lhs_match; +} + +TVM_REGISTER_GLOBAL("relax.dpl.OrRewriter") + .set_body_typed([](PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { + return OrRewriter(lhs, rhs); + }); + +OrRewriter::OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { + auto node = make_object(); + node->lhs = std::move(lhs); + node->rhs = std::move(rhs); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(TupleRewriterNode); + +RewriteSpec TupleRewriterNode::RewriteBindings(const Array& bindings) const { + CHECK_LE(patterns.size(), 3) << "For performance reasons, " + << "matching of implicit tuple patterns is currently limited" + << " to tuples with 3 elements or fewer."; + Map variable_rewrites = GenerateVariableRewrites(bindings); + + if (variable_rewrites.size()) { + return RewriteSpec{variable_rewrites, new_subroutines}; + } else { + return RewriteSpec(); + } +} + +Map TupleRewriterNode::GenerateVariableRewrites(const Array& bindings) const { + Map rewrites; + + Map binding_lookup; + + std::vector info_vec; + + std::unordered_map binding_index_lookup; + + // Initialize a vector of indices, each of which corresponds to a + // potential match for a tuple element. + // + // \param tuple_index_of_current_expr The index for the most recent + // binding. + // + // \param indices An output vector, into which indices will be + // generated. + // + // \returns bool True if the indices could be initialized to a + // potential match. False, otherwise. + auto initialize_indices = [&](size_t tuple_index_of_current_expr, + std::vector& indices) -> bool { + if (!info_vec.back().matches[tuple_index_of_current_expr]) { + return false; + } + + indices = std::vector(patterns.size(), info_vec.size()); + + indices[tuple_index_of_current_expr] = info_vec.size() - 1; + + for (size_t i_rev = 0; i_rev < indices.size(); i_rev++) { + size_t i = indices.size() - i_rev - 1; + if (indices[i] == info_vec.size() - 1) { + continue; + } + + auto binding_index = [&]() -> std::optional { + if (indices[i] == info_vec.size() - 1) { + return info_vec.size() - 1; + } + + for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) { + size_t j = info_vec.size() - j_rev - 1; + if (info_vec[j].matches[i] && !info_vec[j].used && + std::all_of(indices.begin() + (j + 1), indices.end(), + [j](size_t prev_binding_index) { return j != prev_binding_index; })) { + return j; + } + } + + return std::nullopt; + }(); + + if (binding_index.has_value()) { + indices[i] = binding_index.value(); + } else { + return false; + } + } + + return true; + }; + + auto decrement_indices = [&](std::vector& indices) -> bool { + ICHECK_EQ(indices.size(), patterns.size()); + + // Step 1, find the first index that can be decremented, while + // still generating a valid set of indices. + size_t i_forward; + for (i_forward = 0; i_forward < indices.size(); i_forward++) { + if (indices[i_forward] == info_vec.size() - 1) { + continue; + } + + bool found_valid = false; + size_t& index = indices[i_forward]; + while (index) { + index--; + if (info_vec[index].matches[i_forward] && !info_vec[index].used && + std::all_of( + indices.begin() + (i_forward + 1), indices.end(), + [index](size_t later_binding_index) { return index != later_binding_index; })) { + found_valid = true; + break; + } + } + if (found_valid) { + break; + } + } + + // Step 2, if we reached the end, then all indices were + // decremented to zero without finding anything. Return false to + // indicate that we've reached the end. + if (i_forward == indices.size()) { + return false; + } + + // Step 3, refill all indices that were decremented to zero before from 0 to + for (size_t i = 0; i < i_forward; i++) { + size_t i_backward = i_forward - (i + 1); + if (indices[i_backward] == info_vec.size() - 1) { + continue; + } + + auto binding_index = [&]() -> std::optional { + for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) { + size_t j = info_vec.size() - j_rev - 1; + if (info_vec[j].matches[i_backward] && !info_vec[j].used && + std::all_of(indices.begin() + (j + 1), indices.end(), + [j](size_t prev_binding_index) { return j != prev_binding_index; })) { + return j; + } + } + + return std::nullopt; + }(); + + if (binding_index.has_value()) { + indices[i_backward] = binding_index.value(); + } else { + return false; + } + } + + return true; + }; + + for (size_t i_binding = 0; i_binding < bindings.size(); i_binding++) { + const auto& binding = bindings[i_binding]; + + auto expr = GetBoundValue(binding); + + binding_index_lookup[binding->var] = i_binding; + + info_vec.push_back(VarInfo{ + binding->var, + expr, + patterns.Map( + [&](const DFPattern& pat) { return ExtractMatchedExpr(pat, expr, binding_lookup); }), + std::unordered_set(), + false, + }); + + auto new_match = [&]() -> std::optional, std::vector>> { + std::vector indices; + for (size_t i = 0; i < patterns.size(); i++) { + if (initialize_indices(patterns.size() - i - 1, indices)) { + do { + if (auto match = TryMatchByBindingIndex(info_vec, indices)) { + return std::pair{indices, match.value()}; + } + } while (decrement_indices(indices)); + } + } + return std::nullopt; + }(); + + if (new_match) { + const auto& [indices, exprs] = new_match.value(); + ICHECK_EQ(indices.size(), exprs.size()); + for (size_t i = 0; i < indices.size(); i++) { + ICHECK_LT(indices[i], info_vec.size()); + auto& info = info_vec[indices[i]]; + + ICHECK(!info.used) << "InternalError: " + << "Produced multiple replacements for variable " << info.var; + + rewrites.Set(info.var, exprs[i]); + binding_lookup.erase(info.var); + info.used = true; + } + } else { + binding_lookup.Set(binding->var, expr); + } + + for (const auto& prev_var : FreeVars(expr)) { + if (auto it = binding_index_lookup.find(prev_var); it != binding_index_lookup.end()) { + info_vec[it->second].downstream_usage.insert(binding->var); + } + } + } + + return rewrites; +} + +std::optional> TupleRewriterNode::TryMatchByBindingIndex( + const std::vector& info_vec, const std::vector& indices) const { + ICHECK_GE(indices.size(), 1); + + ICHECK_EQ(indices.size(), patterns.size()); + for (size_t i = 0; i < indices.size(); i++) { + const auto& info = info_vec[indices[i]]; + if (info.used || !info.matches[i]) { + return std::nullopt; + } + } + + Map merged_matches = info_vec[indices[0]].matches[0].value(); + for (size_t i = 1; i < indices.size(); i++) { + for (const auto& [pat, expr] : info_vec[indices[i]].matches[i].value()) { + if (auto it = merged_matches.find(pat); it != merged_matches.end()) { + if (!StructuralEqual()(expr, (*it).second)) { + return std::nullopt; + } + } else { + merged_matches.Set(pat, expr); + } + } + } + + bool tuple_element_is_already_used_outside_of_matched_tuple = [&]() -> bool { + std::unordered_set matched_vars; + for (const auto& [pat, expr] : merged_matches) { + if (auto opt = expr.as()) { + matched_vars.insert(opt.value()); + } + } + + for (size_t index : indices) { + const auto& downstream_of_rewritten_var = info_vec[index].downstream_usage; + + for (const auto& uses_matched_var : downstream_of_rewritten_var) { + if (!matched_vars.count(uses_matched_var)) { + return true; + } + } + } + + return false; + }(); + if (tuple_element_is_already_used_outside_of_matched_tuple) { + return std::nullopt; + } + + auto full_tuple = [&]() -> relax::Expr { + Array fields; + for (size_t index : indices) { + fields.push_back(info_vec[index].expr); + } + return relax::Tuple(fields); + }(); + + auto opt_rewritten = func(full_tuple, merged_matches); + if (!opt_rewritten) { + return std::nullopt; + } + auto rewritten = opt_rewritten.value(); + + if (rewritten.same_as(full_tuple)) { + return std::nullopt; + } + + std::vector rewrites; + if (auto inline_tuple = rewritten.as()) { + const auto& fields = inline_tuple->fields; + CHECK_EQ(fields.size(), indices.size()) + << "Expected to receive " << indices.size() << " values to replace TuplePattern with " + << indices.size() << " fields, but received " << fields.size() << " values"; + rewrites = {fields.begin(), fields.end()}; + } else { + for (size_t i = 0; i < indices.size(); i++) { + rewrites.push_back(TupleGetItem(rewritten, i)); + } + } + return rewrites; +} + +TVM_REGISTER_GLOBAL("relax.dpl.TupleRewriter") + .set_body_typed([](Array patterns, + TypedPackedFunc(Expr, Map)> func) { + return TupleRewriter(patterns, func); + }); + +TupleRewriter::TupleRewriter(Array patterns, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, + Map new_subroutines) { + auto node = make_object(); + node->patterns = std::move(patterns); + node->func = std::move(func); + node->additional_bindings = std::move(additional_bindings); + node->new_subroutines = std::move(new_subroutines); + data_ = std::move(node); +} + +PatternMatchingRewriter PatternMatchingRewriter::FromPattern( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, Map new_subroutines) { + if (auto or_pattern = pattern.as()) { + auto new_additional_bindings = additional_bindings.value_or({}); + new_additional_bindings.push_back(pattern); + return OrRewriter(PatternMatchingRewriter::FromPattern( + or_pattern->left, func, new_additional_bindings, new_subroutines), + PatternMatchingRewriter::FromPattern( + or_pattern->right, func, new_additional_bindings, new_subroutines)); + } else if (auto tuple_pattern = pattern.as()) { + auto new_additional_bindings = additional_bindings.value_or({}); + new_additional_bindings.push_back(pattern); + // If the Tuple appears as a Relax binding, apply it first. As a + // fallback, also check for implicit tuples. + return OrRewriter( + ExprPatternRewriter(pattern, func, additional_bindings, new_subroutines), + TupleRewriter(tuple_pattern->fields, func, new_additional_bindings, new_subroutines)); + } else { + return ExprPatternRewriter(pattern, func, additional_bindings, new_subroutines); + } +} + +PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { + Function func_pattern = [&]() { + CHECK(mod->ContainGlobalVar("pattern")) + << "KeyError: " + << "Expected module to contain 'pattern', " + << "a Relax function defining the pattern to be matched, " + << "but the module did not contain a 'pattern' function."; + auto base_func = mod->Lookup("pattern"); + CHECK(base_func->IsInstance()) + << "TypeError: " + << "Expected module to contain 'pattern', " + << "a Relax function defining the pattern to be matched, " + << "but the 'pattern' function was of type " << base_func->GetTypeKey() << "."; + return Downcast(base_func); + }(); + Function func_replacement = [&]() { + CHECK(mod->ContainGlobalVar("replacement")) + << "KeyError: " + + << "Expected module to contain 'replacement', " + << "a Relax function defining the replacement to be matched, " + << "but the module did not contain a 'replacement' function."; + auto base_func = mod->Lookup("replacement"); + CHECK(base_func->IsInstance()) + << "TypeError: " + << "Expected module to contain 'replacement', " + << "a Relax function defining the replacement to be made on a successful match, " + << "but the 'replacement' function was of type " << base_func->GetTypeKey() << "."; + return Downcast(base_func); + }(); + + Map new_subroutines; + for (const auto& [gvar, func] : mod->functions) { + if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") { + bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + CHECK(!is_public) << "ValueError: " + << "Expected module to have no publicly-exposed functions " + << "other than 'pattern' and 'replacement'. " + << "However, function '" << gvar->name_hint << "' of type " + << func->GetTypeKey() << " is publicly exposed."; + new_subroutines.Set(gvar, func); + } + } + + auto sinfo_pattern = GetStructInfo(func_pattern); + auto sinfo_replacement = GetStructInfo(func_replacement); + CHECK(StructuralEqual()(sinfo_pattern, sinfo_replacement)) + << "ValueError: " + << "The pattern and replacement must have the same signature, " + << "but the pattern has struct info " << sinfo_pattern + << ", while the replacement has struct info " << sinfo_replacement; + + Array param_wildcards; + Map pattern_lookup; + for (const auto& param : func_pattern->params) { + WildcardPattern wildcard; + param_wildcards.push_back(wildcard); + pattern_lookup.Set(param, StructInfoPattern(wildcard, GetStructInfo(param))); + } + + std::function make_pattern = [&](Expr expr) -> DFPattern { + if (auto var = expr.as()) { + return pattern_lookup[var.value()]; + + } else if (auto call = expr.as()) { + auto op = make_pattern(call->op); + auto args = call->args.Map(make_pattern); + return CallPattern(op, args); + + } else if (auto tuple = expr.as()) { + auto fields = tuple->fields.Map(make_pattern); + return TuplePattern(fields); + + } else if (auto tuple_get_item = expr.as()) { + auto tuple = make_pattern(tuple_get_item->tuple); + return TupleGetItemPattern(tuple, tuple_get_item->index); + + } else if (auto op = expr.as()) { + return ExprPattern(op.value()); + + } else if (auto func = expr.as()) { + return ExternFuncPattern(func->global_symbol); + + } else if (auto prim = expr.as()) { + return StructInfoPattern(WildcardPattern(), PrimStructInfo(prim->value)); + + } else { + LOG(FATAL) << "TypeError: " + << "Cannot convert Relax expression of type " << expr->GetTypeKey() + << " into pattern-matching rule."; + } + }; + + for (const auto& block : func_pattern->body->blocks) { + for (const auto& binding : block->bindings) { + auto value_pattern = make_pattern(GetBoundValue(binding)); + if (auto match_cast = binding.as()) { + value_pattern = StructInfoPattern(value_pattern, match_cast->struct_info); + } + pattern_lookup.Set(binding->var, value_pattern); + } + } + + DFPattern top_pattern = make_pattern(func_pattern->body->body); + + TypedPackedFunc(Expr, Map)> rewriter_func = + [param_wildcards = std::move(param_wildcards), + orig_func_replacement = std::move(func_replacement)]( + Expr expr, Map matches) -> Optional { + auto func_replacement = CopyWithNewVars(orig_func_replacement); + + Array new_blocks; + + Array wildcard_bindings; + ICHECK_EQ(param_wildcards.size(), func_replacement->params.size()); + for (size_t i = 0; i < param_wildcards.size(); i++) { + Expr matched_expr = matches[param_wildcards[i]]; + + // Introduce an intermediate variable, to ensure that the + // MatchCast's target will be a Var, even for expressions that + // wouldn't normally be normalized into a variable. + Var intermediate_var("intermediate_var", GetStructInfo(matched_expr)); + wildcard_bindings.push_back(VarBinding(intermediate_var, matched_expr)); + wildcard_bindings.push_back( + MatchCast(func_replacement->params[i], intermediate_var, GetStructInfo(matched_expr))); + } + + new_blocks.push_back(DataflowBlock(wildcard_bindings)); + + for (const auto& block : func_replacement->body->blocks) { + new_blocks.push_back(block); + } + + return SeqExpr(new_blocks, func_replacement->body->body); + }; + + return PatternMatchingRewriter::FromPattern(top_pattern, rewriter_func, NullOpt, new_subroutines); +} + +Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional> bindings_opt) { + auto bindings = bindings_opt.value_or({}); + DFPatternMatcher matcher(bindings); + + if (!matcher.Match(pattern, expr)) { + return NullOpt; + } + + return matcher.GetMemo(); +} + +TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); + +bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { + return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); + +/*! + * \brief Apply pattern matching to each expression, replacing + * matches with the output of a user-provided rewriter function. + */ +class PatternMatchingMutator : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + PatternMatchingMutator(const PatternMatchingRewriterNode* rewriter) : rewriter_(rewriter) {} + + Map GetNewSubroutines() const { return new_subroutines_; } + + Expr VisitExpr_(const SeqExprNode* seq) override { + SeqExpr prev = Downcast(ExprMutator::VisitExpr_(seq)); + + StructuralEqual struct_equal; + + while (auto opt = TryRewriteSeqExpr(prev)) { + SeqExpr next = Downcast(builder_->Normalize(opt.value())); + if (struct_equal(prev, next)) { + break; + } + + // Canonicalization may result in two previously-different + // expressions being recognized as identical. Elimination of + // common subexpressions may result in trival var-to-var + // bindings that can be canonicalized. Therefore, iterate the + // simplification steps until converged. + while (true) { + auto start_of_loop = next; + next = Downcast(CanonicalizeBindings(next)); + next = Downcast(EliminateCommonSubexpr(next)); + next = Downcast(RemoveAllUnused(next)); + if (struct_equal(start_of_loop, next)) { + break; + } + } + + if (struct_equal(prev, next)) { + break; + } + + prev = next; + } + + return prev; + } + + Optional TryRewriteSeqExpr(const SeqExpr& seq) { + Array old_blocks = seq->blocks; + + // If the SeqExpr's output is not a variable, treat it as if it + // were the last variable binding of the last block. This + // simplifies the special handling of the SeqExpr's body. + Optional dummy_output_var = NullOpt; + if (!seq->body->IsInstance()) { + dummy_output_var = Var("dummy_output_var", GetStructInfo(seq->body)); + VarBinding dummy_binding(dummy_output_var.value(), seq->body); + + auto last_block = [&]() { + if (seq->blocks.size()) { + auto last_block = old_blocks.back(); + old_blocks.pop_back(); + return last_block; + } else { + return BindingBlock(Array{}); + } + }(); + + last_block.CopyOnWrite()->bindings.push_back(dummy_binding); + old_blocks.push_back(last_block); + } + + auto rewrite_block = [&](Array orig_bindings) -> Array { + auto rewrites = rewriter_->RewriteBindings(orig_bindings); + if (!rewrites) return orig_bindings; + + for (auto [gvar, func] : rewrites.new_subroutines) { + new_subroutines_.Set(gvar, func); + } + + auto bindings = orig_bindings.Map([&](Binding binding) -> Binding { + if (auto new_expr = rewrites.variable_rewrites.Get(binding->var)) { + if (auto match_cast = binding.as()) { + return MatchCast(binding->var, new_expr.value(), match_cast->struct_info); + } else { + return VarBinding(binding->var, new_expr.value()); + } + } else { + return binding; + } + }); + + if (bindings.same_as(orig_bindings)) { + return orig_bindings; + } + + // The rewriter may have introduced additional dependencies + // between computations. Since pattern-matching only occurs + // within blocks that may be re-ordered, these can be resolved + // by performing a topological sort. + bindings = TopologicalSort(bindings); + + return bindings; + }; + + // Utility function to return the rewrites that should be applied + // to a given block. + auto get_rewrites = [&](BindingBlock block) -> Array { + if (block.as()) { + // Early return for DataflowBlock. Since neither control flow + // nor impure functions are allowed within the dataflow block, + // all bindings may be considered at the same time. + return rewrite_block(block->bindings); + } + + RewriteSpec rewrites; + + Array collected_bindings; + Array finalized_bindings; + + auto handle_collected_rewrites = [&]() { + if (collected_bindings.size()) { + auto bindings = rewrite_block(collected_bindings); + if (finalized_bindings.empty()) { + finalized_bindings = bindings; + } else { + for (const auto& binding : bindings) { + finalized_bindings.push_back(binding); + } + } + collected_bindings.clear(); + } + }; + + for (const auto& binding : block->bindings) { + auto value = GetBoundValue(binding); + bool is_dataflow = (!value.as()) && + (!(value.as() && IsImpureCall(Downcast(value)))); + if (is_dataflow) { + // This binding satisfies the dataflow constraints. + collected_bindings.push_back(binding); + } else { + // This binding does not satisfy the dataflow constraints. + // Any operations prior to this binding should be checked + // for pattern-match replacements. + handle_collected_rewrites(); + finalized_bindings.push_back(binding); + } + } + + // Check for rewrites in dataflow operations after the last + // non-dataflow segment. + handle_collected_rewrites(); + + return finalized_bindings; + }; + + // Utility function, check for and apply rewrites to a single + // block. + auto visit_block = [&](BindingBlock old_block) -> BindingBlock { + auto new_bindings = get_rewrites(old_block); + if (new_bindings.same_as(old_block->bindings)) { + return old_block; + } + + if (old_block.as()) { + builder_->BeginDataflowBlock(); + } else { + builder_->BeginBindingBlock(); + } + + for (const auto& binding : new_bindings) { + auto value = builder_->Normalize(GetBoundValue(binding)); + + if (binding.as()) { + builder_->EmitNormalized(VarBinding(binding->var, value)); + } else if (auto match_cast = binding.as()) { + builder_->EmitNormalized(MatchCast(binding->var, value, match_cast->struct_info)); + } else { + LOG(FATAL) << "Binding must be either VarBinding or MatchCast"; + } + } + return builder_->EndBlock(); + }; + + auto new_blocks = old_blocks.Map(visit_block); + if (old_blocks.same_as(new_blocks)) { + return NullOpt; + } + + // Restore the body of the SeqExpr, if needed. + auto new_body = [&]() -> Expr { + if (dummy_output_var) { + auto last_block = new_blocks.back(); + new_blocks.pop_back(); + + auto last_binding = last_block->bindings.back(); + last_block.CopyOnWrite()->bindings.pop_back(); + ICHECK(last_binding->var.same_as(dummy_output_var)); + + if (last_block->bindings.size()) { + new_blocks.push_back(last_block); + } + + return GetBoundValue(last_binding); + } else { + return seq->body; + } + }(); + + return SeqExpr(new_blocks, new_body); + } + + private: + const PatternMatchingRewriterNode* rewriter_; + Map new_subroutines_; +}; + +Expr PatternMatchingRewriter::operator()(Expr expr) { + PatternMatchingMutator mutator(get()); + auto new_expr = mutator(expr); + auto new_subroutines = mutator.GetNewSubroutines(); + CHECK_EQ(new_subroutines.size(), 0) + << "If PatternMatchingRewriter provides subroutines, " + << "then it must be applied to an entire IRModule. " + << "However, PatternMatchingRewriter produced subroutines " << [&]() -> Array { + std::vector vec; + for (const auto& [gvar, func] : new_subroutines) { + vec.push_back(gvar); + } + std::sort(vec.begin(), vec.end(), + [](const GlobalVar& a, const GlobalVar& b) { return a->name_hint < b->name_hint; }); + return vec; + }() << "when applied to " + << "Relax expression of type " << expr->GetTypeKey(); + return new_expr; +} + +IRModule PatternMatchingRewriterNode::operator()( + IRModule mod, const tvm::transform::PassContext& pass_ctx) const { + PatternMatchingMutator mutator(this); + + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + auto rewritten = Downcast(mutator(func.value())); + if (!rewritten.same_as(base_func)) { + updates->Add(gvar, rewritten); + } + } + } + + if (updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + write_ptr->Update(updates); + write_ptr->Update(IRModule(mutator.GetNewSubroutines())); + } + + return mod; +} +tvm::transform::PassInfo PatternMatchingRewriterNode::Info() const { + return tvm::transform::PassInfo(0, "PatternMatchingRewriter", {}, false); +} + +Function RewriteCall(const DFPattern& pat, + TypedPackedFunc)> rewriter, Function func) { + return Downcast(PatternMatchingRewriter::FromPattern(pat, rewriter)(func)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index c0b8d1e1df08..417a78f0d04b 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -22,6 +22,8 @@ * \brief The dataflow pattern matcher for Relax. */ +#include "dataflow_matcher.h" + #include #include #include @@ -37,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -45,7 +48,6 @@ #include "../../arith/constraint_extract.h" #include "../transform/utils.h" -#include "dataflow_matcher_impl.h" namespace tvm { namespace relax { @@ -59,7 +61,7 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { return VisitDFPattern(pattern, expr); } -static Expr TryGetValOfVar(Expr expr, const Map& var2val) { +Expr DFPatternMatcher::UnwrapBindings(Expr expr, const Map& var2val) { auto unwrap = [&](Expr expr) -> Optional { // Unwrap variables into the value to which they are bound. if (var2val.size()) { @@ -98,16 +100,15 @@ void DFPatternMatcher::ClearMap(size_t watermark) { bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { CHECK(pattern.defined()) << "Null pattern found when matching against " << expr0; - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (memoize_ && memo_.count(pattern)) { - ICHECK_EQ(memo_[pattern].size(), 1); - return expr.same_as(memo_[pattern][0]); + return expr.same_as(memo_[pattern]); } else { PrimExpr cached_condition = symbolic_expr_condition_; size_t watermark = matched_nodes_.size(); bool out = DFPatternFunctor::VisitDFPattern(pattern, expr); if (out) { - memo_[pattern].push_back(expr); + memo_[pattern] = expr; matched_nodes_.push_back(pattern); } else { ClearMap(watermark); @@ -118,17 +119,17 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr } bool DFPatternMatcher::VisitDFPattern_(const OrPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const AndPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return VisitDFPattern(op->left, expr) && VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const NotPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return !VisitDFPattern(op->reject, expr); } @@ -183,7 +184,7 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { } bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = VisitDFPattern(attr_pattern->pattern, expr); if (!matches) return matches; VLOG(1) << "considering AttrPatternNode at:\n" << expr; @@ -241,7 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons } bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); // utilities auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { if (op) { @@ -351,12 +352,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return StructuralEqual()(op->expr, expr); } bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = false; if (const auto* func = expr.as()) { matches = true; @@ -379,7 +380,7 @@ bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr } bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* tuple_get_item_node = expr.as()) { return (op->index == -1 || op->index == tuple_get_item_node->index) && VisitDFPattern(op->tuple, tuple_get_item_node->tuple); @@ -388,7 +389,7 @@ bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const } bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = false; if (const auto* tuple_node = expr.as()) { matches = true; @@ -429,7 +430,7 @@ bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::Array } bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* tuple_node = expr.as()) { if (op->fields.size() == tuple_node->fields.size()) { @@ -449,7 +450,7 @@ bool DFPatternMatcher::VisitDFPattern_(const StructInfoPatternNode* op, const Ex return false; } - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); auto expr_struct_info = GetStructInfo(expr); PrimExpr new_constraint = StructInfoBaseCheckPrecondition(op->struct_info, expr_struct_info); @@ -497,7 +498,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { } bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); auto expr_type = expr.as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); } @@ -584,7 +585,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( } bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const ShapeExprNode* shape_expr = expr.as()) return ShapeEqual(&analyzer_, op->fields, shape_expr->values); return false; @@ -609,7 +610,7 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& exp } bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* extern_fn = expr.as()) { return "" == op->global_symbol() || op->global_symbol() == extern_fn->global_symbol; } @@ -618,7 +619,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Ex bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr0) { // constants can be binded to relax.Var as well. - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return expr.as() != nullptr; } @@ -642,631 +643,5 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } -Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, - Optional> bindings_opt) { - auto bindings = bindings_opt.value_or({}); - DFPatternMatcher matcher(bindings); - - if (!matcher.Match(pattern, expr)) { - return NullOpt; - } - - Map matching; - for (const auto& [pat, matches] : matcher.GetMemo()) { - ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; - matching.Set(pat, matches[0]); - } - return matching; -} - -TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); - -bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { - return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); -} - -TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); - -class MatcherUseDefAnalysis : public relax::ExprVisitor { - public: - std::vector vars; - std::map> def2use; - // caller -> callee table. - std::map> caller2callees; - - const VarNode* cur_user_; - - void VisitBinding_(const VarBindingNode* binding) override { - // init - cur_user_ = binding->var.get(); - this->VisitVarDef(binding->var); - this->VisitExpr(binding->value); - cur_user_ = nullptr; - } - - void VisitExpr_(const VarNode* op) override { - if (nullptr == cur_user_) return; - - auto check_and_push = [](std::vector& vec, const VarNode* var) { - if (std::find(vec.begin(), vec.end(), var) == vec.end()) { - vec.push_back(var); - } - }; - - check_and_push(def2use[op], cur_user_); - check_and_push(vars, op); - - caller2callees[cur_user_].push_back(op); - } -}; - -struct PNode { - const DFPatternNode* ptr; - std::vector&>> children; - std::vector&>> parents; -}; - -struct RNode { - const VarNode* ptr; - std::vector children; - std::vector parents; -}; - -struct MatchState { - void add(const PNode* p, const RNode* r) { - match_p_r[p] = r; - match_r_p[r] = p; - } - - void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); } - - void add(MatchState&& other) { - match_p_r.merge(std::move(other.match_p_r)); - match_r_p.merge(std::move(other.match_r_p)); - validated_constraints_.merge(other.validated_constraints_); - } - - const VarNode* matched(const PNode* p) const { - if (auto it = match_p_r.find(p); it != match_p_r.end()) { - return it->second->ptr; - } - return nullptr; - } - - const DFPatternNode* matched(const RNode* r) const { - if (auto it = match_r_p.find(r); it != match_r_p.end()) { - return it->second->ptr; - } - return nullptr; - } - - const VarNode* matched(const PNode& p) const { return matched(&p); } - const DFPatternNode* matched(const RNode& r) const { return matched(&r); } - - bool is_validated(const DFConstraintNode* constraint) const { - return validated_constraints_.count(constraint); - } - - private: - std::unordered_map match_p_r; - std::unordered_map match_r_p; - std::unordered_set validated_constraints_; -}; - -/** - * \brief This method try to match a real node and a pattern node along with its neighbors. - */ -static std::optional TryMatch(const PNode& p, const RNode& r, - const MatchState& current_match, DFPatternMatcher* m, - const MatcherUseDefAnalysis& ud_analysis) { - if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; - - MatchState new_match; - - new_match.add(&p, &r); - - // forward matching; - for (const auto& [pchild, constraints] : p.children) { - bool any_cons_sat = false; - for (const auto& rchild : r.children) { - if (new_match.matched(rchild)) { - // The child variable is already matched to other child pattern in a previous iteration. - continue; - } - if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { - // The child pattern is already matched to other variable in a earlier call to TryMatch. - continue; - } - - const auto& uses = ud_analysis.def2use.at(r.ptr); - - // check edge constraints. - bool all_cons_pass = true; - for (const auto& cons : constraints) { - if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { - all_cons_pass = false; - break; - } - - if (cons.index != -1) { - const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); - if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { - all_cons_pass = false; - break; - } - } - } - if (!all_cons_pass || new_match.matched(pchild)) continue; - any_cons_sat = true; - - if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { - new_match.add(pchild, rchild); - new_match.add(std::move(*match_rec)); - } - } - if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; - } - - return new_match; -} - -static std::optional TryValidate( - const MatchState& current_match, - const std::unordered_map& pattern2node, - const std::vector& validation_constraints, arith::Analyzer* analyzer) { - MatchState new_match; - - std::function(const DFPatternNode*)> query_match_state = - [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> Optional { - auto it = pattern2node.find(pattern); - ICHECK(it != pattern2node.end()) - << "DFConstraint attempted to access DFPattern " << GetRef(pattern) - << ", which does not appear in the PatternContext"; - const auto& p_node = it->second; - if (auto ptr = current_match.matched(p_node)) { - return GetRef(ptr); - } else { - return NullOpt; - } - }; - - for (const auto& constraint : validation_constraints) { - if (!current_match.is_validated(constraint.get())) { - auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); - - necessary_condition = analyzer->Simplify(necessary_condition); - const auto* known = tir::as_const_int(necessary_condition); - - if (known && *known && is_sufficient) { - // The condition passes, and the expression provided is both - // necessary and sufficient for the constraint to pass. Mark - // the constraint as passing, to avoid re-checking it unless - // we backtrack. - new_match.add(constraint.get()); - } else if (known && !*known) { - // The condition fails. Even if additional information would - // be required to pass a constraint, it may bail out early as - // a failure (e.g. shape mismatch in the first two items out - // of N shapes that must all match). - return std::nullopt; - } else if (is_sufficient) { - // The condition depends on dynamic parameters. In the - // future, this may be exposed to the user as a condition for - // optimization, or can be combined with the conditions - // provided from other constraints. - return std::nullopt; - } - } - } - - return new_match; -} - -static std::optional MatchTree( - const MatchState& current_match, size_t current_root_idx, - const std::unordered_map& pattern2node, - const std::unordered_map& var2node, DFPatternMatcher* matcher, - const std::vector& roots, const std::vector& validation_constraints, - const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { - auto get_next_root = [&](size_t root_idx) -> const PNode* { - // Look for the next unmatched root node. - for (; root_idx < roots.size(); ++root_idx) { - const auto& root = pattern2node.at(roots[root_idx].get()); - if (!current_match.matched(root)) { - return &root; - } - } - return nullptr; - }; - - const auto root = get_next_root(current_root_idx); - - if (!root) { - // All root nodes have been matched - return current_match; - } - - MatchState new_match = current_match; - - for (const auto& var : ud_analysis.vars) { - const RNode& r_node = var2node.at(var); - if (new_match.matched(r_node)) continue; - if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { - // Recursively try to match the next subtree. - new_match.add(std::move(*match)); - if (auto validation = - TryValidate(new_match, pattern2node, validation_constraints, analyzer)) { - new_match.add(std::move(*validation)); - if (auto match_rec = - MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, matcher, roots, - validation_constraints, ud_analysis, analyzer)) { - new_match.add(std::move(*match_rec)); - return new_match; - } - } - // Recursive matching has failed, backtrack. - new_match = current_match; - continue; - } - } - - return std::nullopt; -} - -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, - const Map& bindings) { - // TODO(@ganler): Handle non-may external use. - ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; - DFPatternMatcher matcher(bindings); - - MatcherUseDefAnalysis ud_analysis; - ud_analysis.VisitBindingBlock_(dfb.get()); - - // First construct a graph of PNode and RNode. - std::unordered_map var2node; - var2node.reserve(dfb->bindings.size()); - - for (const VarNode* cur_var : ud_analysis.vars) { - const auto& uses = ud_analysis.def2use.at(cur_var); - RNode& cur_node = var2node[cur_var]; - cur_node.ptr = cur_var; - for (const VarNode* use : uses) { - auto& use_node = var2node[use]; - use_node.ptr = use; - cur_node.children.push_back(&use_node); - use_node.parents.push_back(&cur_node); - } - } - - std::unordered_map pattern2node; - pattern2node.reserve(ctx->edge_constraints.size()); - - for (const auto& def_pattern : ctx->src_ordered) { - PNode& def_node = pattern2node[def_pattern.get()]; - const auto& uses = ctx->edge_constraints.at(def_pattern); - def_node.ptr = def_pattern.get(); - def_node.children.reserve(uses.size()); - for (const auto& [use_pattern, cons] : uses) { - PNode& use_node = pattern2node[use_pattern.get()]; - use_node.ptr = use_pattern.get(); - use_node.parents.emplace_back(&def_node, std::ref(cons)); - def_node.children.emplace_back(&use_node, std::ref(cons)); - } - } - - std::vector roots; - for (const auto& pat : ctx->src_ordered) { - if (pattern2node[pat.get()].parents.empty()) { - roots.push_back(pat); - } - } - - if (roots.empty()) { - return NullOpt; - } - - arith::Analyzer analyzer; - auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, - ctx->validation_constraints, ud_analysis, &analyzer); - if (!match) { - return NullOpt; - } - - Map ret; - for (const auto& [pat, p_node] : pattern2node) { - ICHECK(match->matched(p_node)); - ret.Set(GetRef(pat), GetRef(match->matched(p_node))); - } - return ret; -} - -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb, AnalyzeVar2Value(dfb)); -} - -TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") - .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb); - }); - -/*! - * \brief Apply pattern matching to each dataflow block, replacing matches - * with the output of a user-provided rewriter function. - */ -class BlockPatternRewriter : ExprMutator { - public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - BlockPatternRewriter( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter_func) - : ctx_(ctx), rewriter_func_(rewriter_func) {} - - template - static Function Run( - PatternType pat, - TypedPackedFunc(Map, Map)> rewriter_func, - Function func) { - BlockPatternRewriter rewriter(pat, rewriter_func); - - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } - - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); - } - - private: - void EmitUsedVars(Expr val, const Array& pending_bindings, - std::unordered_set* emitted_vars) { - std::unordered_set unemitted_vars; - PostOrderVisit(val, [=, &unemitted_vars](Expr e) { - if (auto v = e.as(); v && !emitted_vars->count(v)) { - unemitted_vars.insert(v); - } - }); - - if (unemitted_vars.empty()) { - return; - } - - size_t num_unemitted = unemitted_vars.size(); - for (size_t i = 0; i < pending_bindings.size(); ++i) { - const auto& binding = pending_bindings[i]; - if (auto var_bind = binding.as(); - var_bind && unemitted_vars.count(var_bind->var.get())) { - // var_bind->value may also depend on other unemitted vars in this range - Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); - EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); - this->VisitBinding(binding); - emitted_vars->insert(var_bind->var.get()); - if (--num_unemitted == 0) { - return; - } - } - } - } - - // Repeat until all matchable subsets of bindings are rewritten. - BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { - auto df_block = Downcast(block); - Map bindings = AnalyzeVar2Value(df_block); - if (auto matches = MatchGraph(ctx_, df_block, bindings)) { - builder_->BeginDataflowBlock(); - Map replacements = rewriter_func_(matches.value(), bindings); - - std::unordered_set emitted_vars; - - bool changed = false; - for (size_t i = 0; i < block->bindings.size(); ++i) { - const auto& binding = block->bindings[i]; - if (auto var_bind = binding.as()) { - if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); - !StructuralEqual()(var_bind->value, new_val)) { - Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); - // Make sure there is no unbound variable used in the new value before it is emitted - EmitUsedVars(new_val, pending_bindings, &emitted_vars); - this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); - changed = true; - } else if (!emitted_vars.count(var_bind->var.get())) { - this->VisitBinding(binding); - emitted_vars.insert(var_bind->var.get()); - } - } else { - this->VisitBinding(binding); - } - } - - auto new_block = builder_->EndBlock(); - - if (!changed) return new_block; - return RewriteDataflowBlockFixedPoint(new_block); - } - return block; - } - - /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - PatternContext ctx_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Map, Map) -> Map - * - * Given the map of patterns and corresponding variables (bound - * variables or parameters), it should return a map that - * specifies new values for matched bound variables. It can refer - * to the passed bindings to create the replacement expressions. - */ - TypedPackedFunc(Map, Map)> rewriter_func_; -}; - -/*! - * \brief Apply pattern matching to each expression, replacing - * matches with the output of a user-provided rewriter function. - */ -class ExprPatternRewriter : ExprMutator { - public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - ExprPatternRewriter(DFPattern pat, - TypedPackedFunc)> rewriter_func) - : pattern_(pat), rewriter_func_(rewriter_func) {} - - template - static Function Run(PatternType pat, - TypedPackedFunc)> rewriter_func, - Function func) { - ExprPatternRewriter rewriter(pat, rewriter_func); - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } - - Expr VisitExpr_(const SeqExprNode* seq) override { - auto cache = bindings_; - SeqExpr prev = GetRef(seq); - - StructuralEqual struct_equal; - - while (true) { - SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Canonicalization may result in two previously-different - // expressions being recognized as identical. Elimination of - // common subexpressions may result in trival var-to-var - // bindings that can be canonicalized. Therefore, iterate the - // simplification steps until converged. - while (true) { - auto start_of_loop = next; - next = Downcast(CanonicalizeBindings(next)); - next = Downcast(EliminateCommonSubexpr(next)); - next = Downcast(RemoveAllUnused(next)); - if (struct_equal(start_of_loop, next)) { - break; - } - } - - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Reset all knowledge of bindings that were collected from - // this SeqExpr. The collected bindings are only after - // the point where they were collected, and we are repeating - // the mutation of this SeqExpr. - bindings_ = cache; - prev = next; - } - } - - void VisitBinding_(const VarBindingNode* binding) override { - auto expr = VisitExpr(binding->value); - bindings_.Set(binding->var, expr); - ReEmitBinding(binding, expr); - } - - Expr VisitExpr(const Expr& expr) override { - auto node = ExprMutator::VisitExpr(expr); - - std::vector matches_top_level; - if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { - return builder_->Normalize(rewritten.value()); - } - - return node; - } - - private: - Optional TryRewrite(const Expr& expr, const DFPattern& pattern, - std::vector* matches_top_level) { - ICHECK(matches_top_level); - - // Special handling if the user-supplied pattern is a `OrPattern`. - // While the `ExtractMatchedExpr` can handle matching the - // `OrPattern`, it will return on the first match, even if the - // `rewriter_func_` doesn't apply a replacement. Unpacking the - // `OrPattern` here allows the match to be resumed if - // `rewriter_func_` returns the original function unmodified. - // This is only valid for a top-level match. - if (auto or_pattern = pattern.as()) { - matches_top_level->push_back(pattern); - Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); - if (!output.defined()) { - output = TryRewrite(expr, or_pattern->right, matches_top_level); - } - matches_top_level->pop_back(); - return output; - } - - if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { - auto matches = opt_matches.value(); - - // Append any additional matches that from the unwrapped - // `OrPattern`. When matching against `pat = pat_lhs | - // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and - // `pat_rhs` separately. The top-level `pat` is never seen by - // `ExtractMatchedExpr`, and must be re-added afterward. - if (matches_top_level->size()) { - auto matched_expr = TryGetValOfVar(expr, bindings_); - for (const auto& pat : *matches_top_level) { - matches.Set(pat, matched_expr); - } - } - - Expr rewritten_expr = rewriter_func_(expr, matches); - if (!rewritten_expr.same_as(expr)) { - return builder_->Normalize(rewritten_expr); - } - } - - return NullOpt; - } - - /*! \brief The pattern for rewriting call nodes */ - DFPattern pattern_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Call, Map) -> Call - * - * Given the matched call node and the map of patterns and - * matched expressions, it should return a new call node to - * replace the original one or the original matched call node as - * is. - */ - TypedPackedFunc)> rewriter_func_; - - /*! \brief The known variable bindings - * - * The variable bindings whose value is known. This must be tracked - * separately from the block builder, so that it can be reset after - * each iteration of the mutate-until-converged loop applied to - * `SeqExpr`. - */ - Map bindings_; -}; - -Function RewriteBindings( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter, Function func) { - return BlockPatternRewriter::Run(ctx, rewriter, func); -} - -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); - -Function RewriteCall(const DFPattern& pat, - TypedPackedFunc)> rewriter, Function func) { - return ExprPatternRewriter::Run(pat, rewriter, func); -} - -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); - } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_matcher_impl.h b/src/relax/ir/dataflow_matcher.h similarity index 91% rename from src/relax/ir/dataflow_matcher_impl.h rename to src/relax/ir/dataflow_matcher.h index a0c35ac0dead..c5d58db5b9d0 100644 --- a/src/relax/ir/dataflow_matcher_impl.h +++ b/src/relax/ir/dataflow_matcher.h @@ -18,11 +18,11 @@ */ /*! - * \file src/tvm/relax/dataflow_matcher_impl.h + * \file src/tvm/relax/dataflow_matcher.h * \brief The auxiliary data structure for dataflow matcher. */ -#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ -#define TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ +#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_H_ +#define TVM_RELAX_IR_DATAFLOW_MATCHER_H_ #include #include @@ -43,7 +43,10 @@ class DFPatternMatcher : public DFPatternFunctor> GetMemo() { return Map>(memo_); } + Map GetMemo() { return memo_; } + + /* \brief Unwrap trivial expressions/bindings */ + static Expr UnwrapBindings(Expr expr, const Map& bindings); protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; @@ -88,7 +91,7 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectPtrHash, ObjectPtrEqual> memo_; + std::unordered_map memo_; var2val_t var2val_; std::vector matched_nodes_; PrimExpr symbolic_expr_condition_{Bool(true)}; @@ -99,4 +102,4 @@ class DFPatternMatcher : public DFPatternFunctor +#include +#include +#include + +#include +#include +#include +#include + +#include "dataflow_matcher.h" + +namespace tvm { +namespace relax { + +struct RewriteSpec { + Map variable_rewrites; + Map new_subroutines; + + explicit operator bool() const { return variable_rewrites.size(); } + + void Append(RewriteSpec other); +}; + +class PatternMatchingRewriterNode : public tvm::transform::PassNode { + public: + virtual RewriteSpec RewriteBindings(const Array& bindings) const { + return RewriteSpec(); + } + + void VisitAttrs(AttrVisitor* visitor) {} + + IRModule operator()(IRModule mod, const tvm::transform::PassContext& pass_ctx) const override; + tvm::transform::PassInfo Info() const override; + + static constexpr const char* _type_key = "relax.dpl.PatternMatchingRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(PatternMatchingRewriterNode, PassNode); +}; + +class PatternMatchingRewriter : public tvm::transform::Pass { + public: + static PatternMatchingRewriter FromPattern( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + static PatternMatchingRewriter FromModule(IRModule mod); + + Expr operator()(Expr expr); + using Pass::operator(); + + TVM_DEFINE_OBJECT_REF_METHODS(PatternMatchingRewriter, Pass, PatternMatchingRewriterNode); +}; + +class ExprPatternRewriterNode : public PatternMatchingRewriterNode { + public: + DFPattern pattern; + TypedPackedFunc(Expr, Map)> func; + Optional> additional_bindings; + Map new_subroutines; + + RewriteSpec RewriteBindings(const Array& bindings) const final; + + Optional RewriteExpr(const Expr& expr, const Map& bindings) const; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("pattern", &pattern); + PackedFunc untyped_func = func; + visitor->Visit("func", &untyped_func); + } + + static constexpr const char* _type_key = "relax.dpl.ExprPatternRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(ExprPatternRewriterNode, PatternMatchingRewriterNode); +}; + +class ExprPatternRewriter : public PatternMatchingRewriter { + public: + ExprPatternRewriter(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + TVM_DEFINE_OBJECT_REF_METHODS(ExprPatternRewriter, PatternMatchingRewriter, + ExprPatternRewriterNode); +}; + +class OrRewriterNode : public PatternMatchingRewriterNode { + public: + PatternMatchingRewriter lhs; + PatternMatchingRewriter rhs; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("lhs", &lhs); + visitor->Visit("rhs", &rhs); + } + + static constexpr const char* _type_key = "relax.dpl.OrRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(OrRewriterNode, PatternMatchingRewriterNode); +}; + +class OrRewriter : public PatternMatchingRewriter { + public: + OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs); + + TVM_DEFINE_OBJECT_REF_METHODS(OrRewriter, PatternMatchingRewriter, OrRewriterNode); +}; + +class TupleRewriterNode : public PatternMatchingRewriterNode { + public: + Array patterns; + TypedPackedFunc(Expr, Map)> func; + Optional> additional_bindings; + Map new_subroutines; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("patterns", &patterns); + PackedFunc untyped_func = func; + visitor->Visit("func", &untyped_func); + } + + static constexpr const char* _type_key = "relax.dpl.TupleRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(TupleRewriterNode, PatternMatchingRewriterNode); + + private: + struct VarInfo { + Var var; + Expr expr; + Array>> matches; + std::unordered_set downstream_usage; + bool used = false; + }; + + Map GenerateVariableRewrites(const Array& bindings) const; + + std::optional> TryMatchByBindingIndex(const std::vector& info_vec, + const std::vector& indices) const; +}; + +class TupleRewriter : public PatternMatchingRewriter { + public: + TupleRewriter(Array patterns, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, PatternMatchingRewriter, TupleRewriterNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_IR_DATAFLOW_REWRITER_H_ diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index a14ba1d9aaa1..6ace974985a5 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -21,6 +21,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -576,17 +578,35 @@ Function::Function(Array params, Expr body, Optional ret_struct body_sinfo = GetStructInfo(body); } - if (ret_struct_info.defined()) { - // allow body to override ret if body is more fine-grained. - if (body_sinfo.defined()) { - if (IsBaseOf(ret_struct_info.value(), body_sinfo.value())) { - ret_struct_info = body_sinfo; - } - } - } else { - CHECK(body_sinfo.defined()) - << "Function do not have a return signature and body is not normalized"; - ret_struct_info = body_sinfo; + CHECK(body_sinfo.defined() || ret_struct_info.defined()) + << "Function must be constructed with either " + << "an explicit struct info for the return type, " + << "or a normalized body with struct info."; + + // Use the body's struct info if there is no explicit return type, + // or if the body may provide a more granular return type. + bool use_body_struct_info = + !ret_struct_info.defined() || + (body_sinfo && ret_struct_info && IsBaseOf(ret_struct_info.value(), body_sinfo.value())); + + if (use_body_struct_info) { + // MatchCast nodes within the body may introduce new symbolic + // variables. These are in-scope for the function body, but not + // for the function's return type. When hoisting the body's type + // to the function return type, symbolic variables may only be + // used if they were defined by the function's parameters. + auto f_shape_var_map = [&] { + auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); + std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); + return [lookup = std::move(lookup)](const tir::Var& var) -> Optional { + if (lookup.count(var)) { + return var; + } else { + return NullOpt; + } + }; + }(); + ret_struct_info = EraseToWellDefined(body_sinfo.value(), f_shape_var_map); } FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 63c74db7e33e..3ee403a25cda 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -606,8 +606,8 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { Expr ExprMutator::VisitExpr_(const IfNode* op) { Expr guard = this->VisitExpr(op->cond); - Expr true_b = this->VisitWithNewScope(op->true_branch); - Expr false_b = this->VisitWithNewScope(op->false_branch); + Expr true_b = this->VisitWithInnerScope(op->true_branch); + Expr false_b = this->VisitWithInnerScope(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { @@ -696,20 +696,24 @@ void ExprMutator::VisitBinding_(const MatchCastNode* binding) { Var new_var = this->VisitVarDef(binding->var); - if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && - new_struct_info.same_as(binding->struct_info)) { - // re-emit old binding if nothing changes - builder_->EmitNormalized(GetRef(binding)); - return; - } + MatchCast new_binding = [&]() -> MatchCast { + if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && + new_struct_info.same_as(binding->struct_info)) { + // re-emit old binding if nothing changes + return GetRef(binding); + } else { + new_value = builder_->NormalizeArgument(new_value); + new_var = WithStructInfo(new_var, new_struct_info); - new_value = builder_->NormalizeArgument(new_value); - new_var = WithStructInfo(new_var, new_struct_info); + var_remap_[binding->var->vid] = new_var; + var_remap_[new_var->vid] = new_var; - var_remap_[binding->var->vid] = new_var; - var_remap_[new_var->vid] = new_var; + return MatchCast(new_var, new_value, new_struct_info, binding->span); + } + }(); - builder_->EmitNormalized(MatchCast(new_var, new_value, new_struct_info, binding->span)); + builder_->EmitNormalized(new_binding); + builder_->AddDefinitionToScope(new_binding->var); } BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { @@ -800,7 +804,31 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> param } builder_->BeginScope(params); + // Outer scope only includes TIR variables that can be inferred from + // the function parameters. With context(builder_->GetAnalyzer(), constraint); + builder_->BeginInnerScope(); + // Inner scope also includes any TIR variables that are defined by + // MatchCast nodes, and are internal to the scope. + Expr ret = this->VisitExpr(expr); + + builder_->EndScope(); + + // Normalization (and the resulting StructInfo inference) of the + // expr occurs outside of the body's parameters, but inside the + // function signature's scope. This keeps variables that are + // inferable based on the function signature, to allow callers to + // propagate StructInfo across the function. + ret = builder_->Normalize(ret); + builder_->EndScope(); + return ret; +} + +Expr ExprMutator::VisitWithInnerScope(const Expr& expr) { + ICHECK(expr->IsInstance()) + << "Normal form requires all new scope is stored as SeqExpr"; + + builder_->BeginInnerScope(); Expr ret = this->VisitExpr(expr); builder_->EndScope(); return ret; diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 12eb81ac675d..d1a9f97337de 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -29,12 +29,119 @@ #include #include #include +#include namespace tvm { namespace relax { namespace { +class SymbolicVarCanonicalizer : public ExprMutator { + public: + Expr VisitExpr_(const FunctionNode* func) override { + auto cached = known_values_; + auto output = ExprMutator::VisitExpr_(func); + known_values_ = cached; + return output; + } + + void VisitBinding_(const MatchCastNode* binding) override { + auto tir_var_map = + InferSymbolicVarMap({{binding->var, binding->value}}, builder_->GetAnalyzer()); + for (const auto& [tir_var, prim_expr] : tir_var_map) { + if (auto it = known_values_.find(tir_var); it != known_values_.end()) { + CHECK(!builder_->GetAnalyzer()->CanProve(it->second.expr != prim_expr)) + << "ValueError: " + << "MatchCast statements must be consistent. " + << "However, the definition of Relax variable " << it->second.source->var + << " implies that TIR variable " << tir_var << " is " << it->second.expr + << ", while the later definition of Relax variable " << binding->var + << " instead implies that TIR variable " << tir_var << " is " << prim_expr; + } else { + known_values_[tir_var] = KnownValue{prim_expr, GetRef(binding)}; + } + } + ExprMutator::VisitBinding_(binding); + } + + Expr VisitExpr_(const IfNode* op) override { + Expr guard = this->VisitExpr(op->cond); + + auto cached = known_values_; + Expr true_b = this->VisitWithInnerScope(op->true_branch); + known_values_ = cached; + Expr false_b = this->VisitWithInnerScope(op->false_branch); + known_values_ = cached; + + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } + + // The two branches may have had different TIR variables inlined. + // For example, one branch has a dynamic implementation and + // produces `R.Tensor([M,N])`, while the other branch checks if + // `N==16` and produces `R.Tensor([M,16])`. After the branch, the + // output is `R.Tensor([M,N])`. However, the `GetStructLCA` would + // correctly return `R.Tensor(ndim=2)`, removing all shape + // information. + // + // Since we know the StructInfo prior to replacing TIR variables, + // this pass can provide a better StructInfo than the generic + // handling in ExprMutator, by restoring the symbolic variables + // within each branch. + auto new_sinfo = VisitExprDepStructInfoField(Downcast(op->struct_info_)); + + StructuralEqual struct_equal; + if (!struct_equal(new_sinfo, GetStructInfo(true_b))) { + auto output_var = Var("then_branch_with_dyn", new_sinfo); + + true_b = SeqExpr({BindingBlock({ + MatchCast(output_var, true_b, new_sinfo), + })}, + output_var); + } + + if (!struct_equal(new_sinfo, GetStructInfo(false_b))) { + auto output_var = Var("else_branch_with_dyn", new_sinfo); + + false_b = SeqExpr({BindingBlock({ + MatchCast(output_var, false_b, new_sinfo), + })}, + output_var); + } + + return If(guard, true_b, false_b, op->span); + } + + PrimExpr VisitPrimExpr(const PrimExpr& expr) override { + if (known_values_.empty()) { + return expr; + } + PrimExpr output = tir::Substitute(expr, [this](const tir::Var& var) -> Optional { + if (auto it = known_values_.find(var); it != known_values_.end()) { + return it->second.expr; + } else { + return NullOpt; + } + }); + if (output.same_as(expr)) { + return expr; + } + + output = builder_->GetAnalyzer()->Simplify(output); + return output; + } + + private: + struct KnownValue { + PrimExpr expr; + MatchCast source; + }; + + std::unordered_map known_values_; +}; + struct CanonicalizationPlan { Map replace_usage; Map replace_binding; @@ -377,16 +484,39 @@ class BindingCanonicalizer : public ExprMutator { }; } // namespace -Expr CanonicalizeBindings(const Expr& expr) { return BindingCanonicalizer::Apply(expr); } +Expr CanonicalizeTIRVariables(Expr expr) { return SymbolicVarCanonicalizer()(std::move(expr)); } + +Expr CanonicalizeRelaxBindings(Expr expr) { return BindingCanonicalizer::Apply(std::move(expr)); } + +Expr CanonicalizeBindings(Expr expr) { + expr = CanonicalizeTIRVariables(std::move(expr)); + expr = CanonicalizeRelaxBindings(std::move(expr)); + return expr; +} namespace transform { +Pass CanonicalizeTIRVariables() { + auto pass_func = [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeTIRVariables(f)); + }; + return CreateFunctionPass(pass_func, 1, "CanonicalizeTIRVariables", {}); +} + +Pass CanonicalizeRelaxBindings() { + auto pass_func = [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeBindings(f)); + }; + return CreateFunctionPass(pass_func, 1, "CanonicalizeRelaxBindings", {}); +} + Pass CanonicalizeBindings() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CanonicalizeBindings(f)); - }; - return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {}); + return tvm::transform::Sequential( + { + CanonicalizeTIRVariables(), + CanonicalizeRelaxBindings(), + }, + "CanonicalizeBindings"); } TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings); diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 5755e118541f..932dca30a110 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -420,7 +420,7 @@ Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false); * * \ret The canonicalized expression */ -Expr CanonicalizeBindings(const Expr& expr); +Expr CanonicalizeBindings(Expr expr); /* \brief Remove use of trivial bindings * diff --git a/src/relax/utils.cc b/src/relax/utils.cc index f0239e424f30..77416dc92b1d 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -122,11 +122,7 @@ tvm::Map InferSymbolicVarMap( if (!var_sinfo) return; auto expr_sinfo = expr.as(); - CHECK(expr_sinfo) << "Cannot bind expression with struct type " << expr - << " to variable with struct type " << var; - CHECK_EQ(var_sinfo->dtype, expr_sinfo->dtype) - << "Cannot bind expression with struct type " << expr << " to variable with struct type " - << var << ", due to conflicting PrimExpr DataType"; + if (!expr_sinfo) return; if (!var_sinfo->value.defined() || !expr_sinfo->value.defined()) return; @@ -139,15 +135,12 @@ tvm::Map InferSymbolicVarMap( if (!var_shape->values.defined()) return; auto expr_shape = expr.as(); - CHECK(expr_shape) << "Cannot bind expression with struct type " << expr - << " to variable with struct type " << var; + if (!expr_shape) return; if (!expr_shape->values.defined()) return; auto var_shape_arr = var_shape->values.value(); auto expr_shape_arr = expr_shape->values.value(); - CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size()) - << "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size() - << " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size(); + if (var_shape_arr.size() != expr_shape_arr.size()) return; for (size_t i = 0; i < var_shape_arr.size(); i++) { bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]); } @@ -159,8 +152,7 @@ tvm::Map InferSymbolicVarMap( if (!var_tensor->shape.defined()) return; auto expr_tensor = expr.as(); - CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr - << " to variable with struct type " << var; + if (!expr_tensor) return; if (!expr_tensor->shape.defined()) return; bind_from_shape(GetStructInfo(var_tensor->shape.value()), diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 792331dda4c0..3153c0770e38 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -46,6 +46,11 @@ void SeqExprFrameNode::EnterWithScope() { BindingBlock()->EnterWithScope(); } +void FunctionFrameNode::EnterWithScope() { + this->block_builder->BeginScope(params); + SeqExprFrameNode::EnterWithScope(); +} + void FunctionFrameNode::ExitWithScope() { using ir::IRModuleFrame; using tvm::relax::Expr; @@ -54,7 +59,7 @@ void FunctionFrameNode::ExitWithScope() { // Step 1: Create the function. CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " "`return` to return an Expr"; - this->block_builder->BeginScope(params); + Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); // if the function is not private, add a global symbol to its attributes if (!is_private.value_or(Bool(false))->value && name.defined() && diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 2e94ae420a97..453c7fdb5522 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -70,15 +70,7 @@ tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_inf FunctionFrame frame = FindFunctionFrame("R.Arg"); tvm::relax::Var var(name, struct_info); frame->params.push_back(var); - - // This constraint would normally be provided as part of - // `BlockBuilder::BeginScope`. However, because the frame and its - // scope are initialized before the arguments are known, the scope - // doesn't have access to these constraints. - auto* analyzer = frame->block_builder->GetAnalyzer(); - for (const auto& tir_var : DefinableTIRVarsInStructInfo(struct_info)) { - analyzer->MarkGlobalNonNegValue(tir_var); - } + frame->block_builder->AddDefinitionToScope(var); return var; } diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py new file mode 100644 index 000000000000..828aa92bda28 --- /dev/null +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -0,0 +1,1512 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + +import pytest + + +def test_rewrite_defined_by_ir_module(): + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function + def before(x: R.Tensor([32], "float32")): + R.func_attr({"global_symbol": "main"}) + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = lhs + rhs + return out + + @R.function + def expected(x: R.Tensor([32], "float32")): + R.func_attr({"global_symbol": "main"}) + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = R.call_pure_packed( + "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + ) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_missing_pattern_raises_error(): + """The rewriter must define a pattern to be matched""" + + with pytest.raises(KeyError, match="pattern"): + + @R.rewriter + class Rewriter: + @R.function + def replacement(): + return R.tuple() + + +def test_incorrect_function_type_of_pattern_raises_error(): + """The rewriter's pattern must be a Relax function""" + + with pytest.raises(TypeError, match="pattern"): + + @R.rewriter + class Rewriter: + @T.prim_func + def pattern(): + pass + + @R.function + def replacement(): + return R.tuple() + + +def test_missing_replacement_raises_error(): + """The rewriter must define a replacement""" + + with pytest.raises(KeyError, match="replacement"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(): + return R.tuple() + + +def test_incorrect_function_type_of_replacement_raises_error(): + """The rewriter's replacement must be a Relax function""" + + with pytest.raises(TypeError, match="replacement"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(): + return R.tuple() + + @T.prim_func + def replacement(): + pass + + +def test_mismatch_of_static_shapes_raises_error(): + """The pattern and replacement must accept the same shapes""" + + with pytest.raises(ValueError, match="must have the same signature"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([32])): + return A + + @R.function + def replacement(A: R.Tensor([16])): + return A + + +def test_rewriter_may_be_applied_to_ir_module(): + """A rewriter may mutate an IRModule + + The `PatternMatchingRewriter.__call__` implementation may accept + either a single Relax function, or an entire IRModule. If it is + passed an IRModule, then all functions in the `IRModule` are + updated. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @I.ir_module + class Before: + @R.function + def func_a(x: R.Tensor([32], "float32")): + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = lhs + rhs + return out + + @R.function + def func_b(x: R.Tensor([16], "float32")): + out = x + x + return out + + @I.ir_module + class Expected: + @R.function + def func_a(x: R.Tensor([32], "float32")): + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = R.call_pure_packed( + "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + ) + return out + + @R.function + def func_b(x: R.Tensor([16], "float32")): + out = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + return out + + After = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewriter_may_be_used_as_ir_transform(): + """A rewriter may be used as a tvm.ir.transform.Pass""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor([16], "float32")): + y = x + x + return y + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor([16], "float32")): + out = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + return out + + After = tvm.ir.transform.Sequential([Rewriter])(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_same_pattern_applied_multiple_times(): + """The pattern-match may apply multiple times""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before(x: R.Tensor([16], "float32")): + y = x + x + z = y + y + return z + + @R.function(private=True) + def expected(x: R.Tensor([16], "float32")): + y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")) + z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32")) + return z + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_composition_of_rewrite_rules(): + """Rewrite rules may be composed together""" + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = A + B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.rewriter + class RewriteMultiply: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = A * B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + D = A + B + E = C * D + return E + + @R.function(private=True) + def expected( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")) + E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32")) + return E + + rewriter = RewriteAdd | RewriteMultiply + + after = rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_recursive_rewrite_rules(): + """Rewrite rules are applied until convergence + + In this test, both the `RewriteAdd` and `RewriteMultiply` patterns + must be applied in order to produce the expected output. However, + the `RewriteMultiply` pattern relies on the expression produced by + the `RewriteAdd` pass. + + """ + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMultiply: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): + C = A * B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): + C = R.call_pure_packed( + "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before(A: R.Tensor([16], "float32")): + B = A + A + return B + + @R.function(private=True) + def expected(A: R.Tensor([16], "float32")): + B = R.call_pure_packed( + "my_optimized_mul_impl", + A, + R.const(2.0, "float32"), + sinfo_args=R.Tensor([16], "float32"), + ) + return B + + rewriter = RewriteAdd | RewriteMultiply + + after = rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_arbitrary_dtype(): + """A pattern-match may apply to a tensor with unknown dtype + + In this test case, a pattern identifies `R.strided_slice` usage + which returns the last slice of an array, and replaces it with a + view into the input array. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]): + M = T.int64() + N = T.int64() + last_slice_2d: R.Tensor([1, N]) = R.strided_slice(A, axes=[0], begin=[M - 1], end=[M]) + last_slice_1d: R.Tensor([N]) = R.squeeze(last_slice_2d, axis=0) + return last_slice_1d + + @R.function + def replacement(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]): + M = T.int64() + N = T.int64() + + # TODO(Lunderberg): Improve this syntax. A Relax + # PrimValue (e.g. `A.dtype.bits`) should be usable in any + # Relax context that accepts a `PrimExpr`. Currently, + # this requires `R.match_cast` to produce a TIR symbolic + # variable from the Relax PrimValue. + bits_per_element = T.uint8() + _ = R.match_cast( + A.dtype.bits, + R.Prim(value=bits_per_element), + ) + lanes_per_element = T.uint16() + _ = R.match_cast( + A.dtype.lanes, + R.Prim(value=lanes_per_element), + ) + + last_slice = R.memory.view( + A, + [N], + relative_byte_offset=(M - 1) + * N + * T.ceildiv( + bits_per_element.astype("int64") * lanes_per_element.astype("int64"), 8 + ), + ) + return last_slice + + @I.ir_module + class Before: + @R.function + def main( + A: R.Tensor([32, 16], "float16"), + B: R.Tensor(["P", "Q"], "int4x8"), + C: R.Tensor([16, 32]), + ): + P = T.int64() + Q = T.int64() + + A_slice_2d = R.strided_slice(A, axes=[0], begin=[31], end=[32]) + A_slice_1d = R.squeeze(A_slice_2d, axis=0) + + B_slice_2d = R.strided_slice(B, axes=[0], begin=[P - 1], end=[P]) + B_slice_1d = R.squeeze(B_slice_2d, axis=0) + + C_slice_2d = R.strided_slice(C, axes=[0], begin=[15], end=[16]) + C_slice_1d = R.squeeze(C_slice_2d, axis=0) + + return (A_slice_1d, B_slice_1d, C_slice_1d) + + @I.ir_module + class Expected: + @R.function + def main( + A: R.Tensor([32, 16], "float16"), + B: R.Tensor(["P", "Q"], "int4x8"), + C: R.Tensor([16, 32]), + ): + P = T.int64() + Q = T.int64() + + # The pattern matches any 2-d tensor, with any data type. + # When the match's shape and dtype are both known, + # normalization and canonicalization produces a statically + # known value for `relative_byte_offset`. + # + # Relative offset is `(31 rows) * + # (16 elements/row) * + # (2 bytes/element)` + A_slice_1d = R.memory.view(A, shape=[16], relative_byte_offset=992) + + # The pattern can also match a 2-d tensor with dynamic + # shape. The `relative_byte_offset` uses the known + # datatype (4 bytes for each int4x8), but with dynamic + # shape variables substituted in where required. + # + # Relative offset is `((P-1) rows) * + # (Q elements/row) * + # (4 bytes/element)` + B_slice_1d = R.memory.view(B, shape=[Q], relative_byte_offset=(P - 1) * Q * 4) + + # The pattern can also match a 2-d tensor with static + # shape, but unknown data type. The + # `relative_byte_offset` is determined based on the known + # number of elements, and the dynamic size of each + # element. + # + # Relative offset is `(15 rows) * + # (32 elements/row) * + # (ceildiv(bits*lanes,8) bytes/element)` + C_bits_per_element = T.uint8() + C_bits_prim_value = C.dtype.bits + _ = R.match_cast( + C_bits_prim_value, + R.Prim(value=C_bits_per_element), + ) + C_lanes_per_element = T.uint16() + C_lanes_prim_value = C.dtype.lanes + _ = R.match_cast( + C_lanes_prim_value, + R.Prim(value=C_lanes_per_element), + ) + + C_slice_1d = R.memory.view( + C, + shape=[32], + relative_byte_offset=( + (C_bits_per_element.astype("int64") * C_lanes_per_element.astype("int64") + 7) + // 8 + ) + * 480, + ) + + return (A_slice_1d, B_slice_1d, C_slice_1d) + + after = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, after) + + +def test_rewrite_may_introduce_private_relax_subroutines(): + """The replacement may contain subroutines""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return Rewriter.subroutine(A) + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B + B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine(A) + C = Expected.subroutine(B) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + After = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewrite_only_introduces_private_subroutines_when_required(): + """Only subroutines that are used will be added to the module + + Like `test_rewrite_may_introduce_private_relax_subroutines`, but + the rewritten function only requires some of the subroutines + provided by the rewriter. + + """ + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return RewriteAdd.subroutine_add(A) + + @R.function(private=True) + def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMul: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A * A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32")) + + @T.prim_func(private=True) + def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B + B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine_add(A) + C = Expected.subroutine_add(B) + return C + + @R.function(private=True) + def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + rewriter = RewriteAdd | RewriteMul + + After = rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewriter_may_not_introduce_public_subroutines(): + """The rewriter may only introduce private functions""" + + with pytest.raises(ValueError, match="is publicly exposed"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return Rewriter.subroutine(A) + + @R.function + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + +def test_rewrite_branches_may_reuse_subroutine_name(): + """Each rewriter is independent, and may reuse subroutine names""" + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return RewriteAdd.subroutine(A) + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMul: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A * A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32")) + + @T.prim_func(private=True) + def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B * B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine(A) + C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32")) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @T.prim_func(private=True) + def subroutine_1(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + rewriter = RewriteAdd | RewriteMul + + After = rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewrite_of_explicit_relax_tuple(): + """The rewriter function may return a tuple + + When it occurs explicitly within the Relax function, the tuple + pattern matches against the Relax tuple, and the Relax tuple is + replaced. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + proj_tuple = (proj_A, proj_B) + out = proj_tuple[0] + proj_tuple[1] + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + out = proj_tuple[0] + proj_tuple[1] + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_output_relax_tuple(): + """The rewriter may update a tuple being returned + + Unlike most relax expressions, tuples may appear as nested + expressions. Pattern-matching should be aware of this option. + + Like `test_rewrite_of_explicit_relax_tuple`, but the tuple appears + as the return value in the function being modified. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + return (proj_A, proj_B) + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple(): + """The rewriter function may return a tuple + + The tuple being replaced does not need to explicitly exist within + the updated Relax function. So long as each element of the tuple + pattern matches a Relax expression, the pattern match can apply. + + This rule ensures that pattern-matching is never broken when + `CanonicalizeBindings` is applied. + + This test is identical to `test_rewrite_of_explicit_relax_tuple`, + except that the function does not contain the round trip of + packing `proj_A` and `proj_B` into a tuple, then immediately + unpacking them from the tuple. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + out = proj_A + proj_B + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + out = proj_tuple[0] + proj_tuple[1] + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple_with_shared_wildcard(): + """Tuple elements may depend on the same input + + Here, both elements of the tuple depend on `y`. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + lhs = x + y + rhs = y + z + return (lhs, rhs) + + @R.function + def replacement( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "optimized_impl", + x, + y, + z, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + lhs = A + B + rhs = B + C + out = R.multiply(lhs, rhs) + return out + + @R.function(private=True) + def expected( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + lhs_rhs = R.call_pure_packed( + "optimized_impl", + A, + B, + C, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + out = R.multiply(lhs_rhs[0], lhs_rhs[1]) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_no_rewrite_of_implicit_tuple_when_shared_wildcard_is_mismatched(): + """Tuple elements must match simultaneously + + Each element of the tuple matches individually, but the two + elements both depend on `B`. Because the first tuple element + would require `y = B`, while the second tuple element would + require `y = C`, the match fails. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + lhs = x + y + rhs = y + z + return (lhs, rhs) + + @R.function + def replacement( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "optimized_impl", + A, + B, + C, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + D: R.Tensor([16], "float32"), + ): + lhs = A + B + rhs = C + D + out = R.multiply(lhs, rhs) + return out + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_implicit_tuple_may_not_introduce_extra_compute(): + """Matching of implicit tuple may not cause extra compute + + Here, the `(proj_A, proj_B)` tuple could be an implcit tuple + match, but that would repeat the computation of `proj_A`. It + would be computed once on its own, to be used for `proj_A_on_B`, + and once for computing `(proj_A, proj_B)`. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16, 16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16, 16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + # This function has no location at which a tuple + # `(proj_A,proj_B)` could be constructed, then unpacked. + + proj_A = R.matmul(A, state) + + # A tuple `(proj_A, proj_B)` could not be constructed at this + # location, because `proj_B` has not yet been computed. + + proj_A_on_B = R.matmul(proj_A, B) + proj_B = R.matmul(proj_A_on_B, state) + + # A tuple `(proj_A, proj_B)` could be constructed here, but a + # use-site of `proj_A` has already occurred. Implicit + # matching of a tuple is only allowed if it would replace + # every use-site of a variable. + + out = proj_A + proj_B + return out + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple_with_three_elements(): + """Implicit tuples may contain three elements""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(qkv: R.Tensor([12288], "float32")): + qkv_tuple = R.split(qkv, 3, axis=0) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + q_embed = R.call_pure_packed( + "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + ) + k_embed = R.call_pure_packed( + "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + ) + + return (q_embed, k_embed, v) + + @R.function + def replacement(qkv: R.Tensor([12288], "float32")): + return R.call_pure_packed( + "split_rotary_embedding", + [qkv], + sinfo_args=[ + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + ], + ) + + @R.function(private=True) + def before( + state: R.Tensor([4096], "float32"), + proj_qkv: R.Tensor([12288, 4096], "float32"), + kv_cache: R.Object, + ): + qkv = R.matmul(proj_qkv, state) + qkv_tuple = R.split(qkv, 3, axis=0) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + q_embed = R.call_pure_packed( + "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + ) + k_embed = R.call_pure_packed( + "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + ) + + attention = R.call_pure_packed( + "compute_self_attention", + [q_embed, k_embed, v, kv_cache], + sinfo_args=R.Tensor([4096]), + ) + + return attention + + @R.function(private=True) + def expected( + state: R.Tensor([4096], "float32"), + proj_qkv: R.Tensor([12288, 4096], "float32"), + kv_cache: R.Object, + ): + qkv = R.matmul(proj_qkv, state) + embedded_qkv_tuple = R.call_pure_packed( + "split_rotary_embedding", + [qkv], + sinfo_args=[ + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + ], + ) + + v = embedded_qkv_tuple[2] + q_embed = embedded_qkv_tuple[0] + k_embed = embedded_qkv_tuple[1] + + attention = R.call_pure_packed( + "compute_self_attention", + [q_embed, k_embed, v, kv_cache], + sinfo_args=R.Tensor([4096]), + ) + + return attention + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_pattern_matching_may_not_reorder_across_impure_functions(): + """Matched pattern must be ordered with respect to impure functions + + To ensure that debug printouts, memory management, performance + measurements, etc are not impacted by a pattern match, a pattern + must be entirely before, or entirely after an impure function. A + pattern match in which some parts of the matched expression are + performed before an impure function, while others are performed + afterwards, is not allowed. + + In this test, the matmul and the add may not be fused, because the + impure print statement occurs between them. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + state = R.matmul(weights, state) + state = R.add(bias, state) + return state + + @R.function + def replacement( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + + @R.function(private=True, pure=False) + def before( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.matmul(weights, state) + R.print(format="After matmul, before add") + state = R.add(bias, state) + R.print(format="End of function") + return state + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_pattern_matching_may_occur_between_impure_functions(): + """Matched pattern may be adjacent to impure functions + + To ensure that debug printouts, memory management, performance + measurements, etc are not impacted by a pattern match, a pattern + must be entirely before, or entirely after an impure function. A + pattern match in which some parts of the matched expression are + performed before an impure function, while others are performed + afterwards, is not allowed. + + In this test, the matmul and the add may be fused, because the + pattern occurs without an impure print statement in-between. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + state = R.matmul(weights, state) + state = R.add(bias, state) + return state + + @R.function + def replacement( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + + @R.function(private=True, pure=False) + def before( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.matmul(weights, state) + state = R.add(bias, state) + R.print(format="End of function") + return state + + @R.function(private=True, pure=False) + def expected( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + R.print(format="End of function") + return state + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_may_apply_within_conditional(): + """Rewrites may apply within to inner dataflow regions + + While dataflow regions may not contain conditionals, they may + occur within the body of conditionals. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + return A + B + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + return R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + + @R.function(private=True) + def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): + if cond: + out = A + B + else: + C = A + B + out = C + B + return out + + @R.function(private=True) + def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): + if cond: + out = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + else: + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + out = R.call_pure_packed( + "my_optimized_add_impl", C, B, sinfo_args=R.Tensor([16], "float32") + ) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_match_dynamic_shape(): + """Pattern match/rewrites may be dynamic + + The tuple being replaced does not need to explicitly exist within + the updated Relax function. So long as each element of the tuple + pattern matches a Relax expression, the pattern match can apply. + + This rule ensures that pattern-matching is never broken when + `CanonicalizeBindings` is applied. + + This test is identical to `test_rewrite_of_explicit_relax_tuple`, + except that the function does not contain the round trip of + packing `proj_A` and `proj_B` into a tuple, then immediately + unpacking them from the tuple. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor(["N1", "M"], "float32"), + lhs_B: R.Tensor(["N2", "M"], "float32"), + rhs: R.Tensor(["M"], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + return (proj_A, proj_B) + + @R.function + def replacement( + lhs_A: R.Tensor(["N1", "M"], "float32"), + lhs_B: R.Tensor(["N2", "M"], "float32"), + rhs: R.Tensor(["M"], "float32"), + ): + N1 = T.int64() + N2 = T.int64() + + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_A: R.Tensor([N1], "float32") = R.strided_slice( + proj_concat, axes=[0], begin=[0], end=[N1] + ) + proj_B: R.Tensor([N2], "float32") = R.strided_slice( + proj_concat, axes=[0], begin=[N1], end=[N2 + N1] + ) + return (proj_A, proj_B) + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + out = proj_A + proj_B + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_A = R.strided_slice(proj_concat, axes=[0], begin=[0], end=[16]) + proj_B = R.strided_slice(proj_concat, axes=[0], begin=[16], end=[32]) + out = proj_A + proj_B + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_match_dynamic_pattern_against_dynamic_shape(): + """A dynamic pattern may match a static shape""" + + @R.rewriter + class Rewriter: + @R.function + def pattern( + A: R.Tensor(["M", "N"], "float32"), + B: R.Tensor(["N", "N"], "float32"), + ): + return R.matmul(A, B) + + @R.function + def replacement( + A: R.Tensor(["M", "N"], "float32"), + B: R.Tensor(["N", "N"], "float32"), + ): + M = T.int64() + N = T.int64() + return R.call_pure_packed( + "my_optimized_square_matmul", + A, + B, + sinfo_args=R.Tensor([M, N], "float32"), + ) + + @R.function(private=True) + def before( + A: R.Tensor(["N", "N*2"], "float32"), + B: R.Tensor(["N*2", "N*2"], "float32"), + C: R.Tensor(["N", "N"], "float32"), + ): + N = T.int64() + D: R.Tensor([N, N * 2], "float32") = R.matmul(A, B) + E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) + F: R.Tensor([N * 2, N], "float32") = R.matmul(E, C) + return F + + @R.function(private=True) + def expected( + A: R.Tensor(["N", "N*2"], "float32"), + B: R.Tensor(["N*2", "N*2"], "float32"), + C: R.Tensor(["N", "N"], "float32"), + ): + N = T.int64() + + D: R.Tensor([N, N * 2], "float32") = R.call_pure_packed( + "my_optimized_square_matmul", + A, + B, + sinfo_args=R.Tensor([N, N * 2], "float32"), + ) + E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) + F: R.Tensor([N * 2, N], "float32") = R.call_pure_packed( + "my_optimized_square_matmul", + E, + C, + sinfo_args=R.Tensor([N * 2, N], "float32"), + ) + return F + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index d513c0cf6c6d..ea3b1c249b8b 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -198,9 +198,13 @@ def test_change_shape(): @I.ir_module class TestChangeShape: @R.function - def main(x: R.Tensor(("m", "n"))): + def main(x: R.Tensor(ndim=2)): y = x - # not trivial: introduces new shape vars + # The MatchCast is non-trivial, as it introduces new shape + # vars. Because the input tensor has an unknown shape + # rather than a symbolic shape, these new shape vars + # cannot be expressed in terms of previous variables. + # Therefore, the match cast must be retained. o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) w = z @@ -210,7 +214,7 @@ def main(x: R.Tensor(("m", "n"))): @I.ir_module class Expected: @R.function - def main(x: R.Tensor(("m", "n"))): + def main(x: R.Tensor(ndim=2)): o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) # the struct_info field on q will need to be updated @@ -220,6 +224,35 @@ def main(x: R.Tensor(("m", "n"))): verify(TestChangeShape, Expected) +def test_replace_symbolic_variable_and_remove_match_cast(): + @I.ir_module + class TestChangeShape: + @R.function + def main(x: R.Tensor(("m", "n"))): + y = x + # The MatchCast is non-trivial, as it introduces new shape + # vars. However, the new shape vars are redundant, and + # are replaced by canonicalization. After replacing the + # new shape vars, the MatchCast is trivial and may be + # removed. + o, p = T.int64(), T.int64() + z = R.match_cast(x, R.Tensor((o, p))) + w = z + q = R.add(w, y) + return R.add(q, w) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"))): + m = T.int64() + n = T.int64() + q: R.Tensor([m, n]) = R.add(x, x) + return R.add(q, x) + + verify(TestChangeShape, Expected) + + def test_unwrap_tuple(): @I.ir_module class Before: @@ -289,6 +322,222 @@ def main() -> R.Tensor((), "int32"): verify(Input, Expected) +def test_fold_variables_from_match_cast(): + """Symbolic variables in R.match_cast may be inferred + + If the argument to `R.match_cast` has known shape parameters, they + may be used to infer symbolic shape parameters. + + """ + + @I.ir_module + class Before: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([16, 16], dtype="float32"), + ): + N1 = T.int64() + M = T.int64() + N2 = T.int64() + + # The symbolic variables `N1`, `N2` and `M` are defined by + # these `R.match_cast` statements. Since the inputs have + # a known shape, the values of these symbolic variables + # may be inferred. + lhs_A = R.match_cast(A, R.Tensor([N1, M], dtype="float32")) + lhs_B = R.match_cast(B, R.Tensor([N2, M], dtype="float32")) + rhs = R.match_cast(state, R.Tensor([M], dtype="float32")) + + # The symbolic shapes propagate downstream. + lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void") + proj_A = R.strided_slice( + proj_concat, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(N1),), + assume_inbound=False, + ) + proj_B = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(N1)], + [R.prim_value(N1 + N2)], + assume_inbound=False, + ) + return (proj_A, proj_B) + + @I.ir_module + class Expected: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([16, 16], dtype="float32"), + ): + # The function no longer depends on symbolic variables. + # Shape inference is now propagated using the + # statically-known shapes. + + lhs: R.Tensor([32, 16], dtype="float32") = R.concat((A, B), axis=0) + proj_concat: R.Tensor([32], dtype="float32") = R.matmul(lhs, state, out_dtype="void") + proj_A: R.Tensor([16], dtype="float32") = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(0)], + [R.prim_value(16)], + assume_inbound=False, + ) + proj_B: R.Tensor([16], dtype="float32") = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(16)], + [R.prim_value(32)], + assume_inbound=False, + ) + return (proj_A, proj_B) + + verify(Before, Expected) + + +def test_inconsistent_match_cast_raises_error(): + """Symbolic variables from R.match_cast must be consistent + + All match cast statements must provide consistent definitions for + symbolic variables. In this test, the value of `M` would be + inferred as 16 from either `state` or `A`, but would be inferred + as 32 from `B`. + + """ + + @I.ir_module + class Before: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([32, 32], dtype="float32"), + ): + N1 = T.int64() + M = T.int64() + N2 = T.int64() + + # These R.match_cast statements define inconsistent values + # for the symbolic shape parameters. + lhs_A = R.match_cast(A, R.Tensor([N1, M], dtype="float32")) + lhs_B = R.match_cast(B, R.Tensor([N2, M], dtype="float32")) + rhs = R.match_cast(state, R.Tensor([M], dtype="float32")) + + lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void") + proj_A = R.strided_slice( + proj_concat, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(N1),), + assume_inbound=False, + ) + proj_B = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(N1)], + [R.prim_value(N1 + N2)], + assume_inbound=False, + ) + return (proj_A, proj_B) + + with pytest.raises(ValueError, match="MatchCast statements must be consistent"): + CanonicalizeBindings()(Before) + + +def test_match_cast_may_have_distinct_values_in_branches(): + """Conditional branches may have different values of symbolic variables + + Here, the value of `N` can be inferred as 16 within the `if` + branch and as 32 within the `else` branch. + + """ + + @I.ir_module + class Before: + @R.function + def main( + state: R.Tensor(["N"], dtype="float32"), + A: R.Tensor(["M", 16], dtype="float32"), + B: R.Tensor(["M", 32], dtype="float32"), + scale: R.Prim("float32"), + ): + N = T.int64() + M = T.int64() + + if N == 16: + weights: R.Tensor([M, 16], "float32") = A * scale + weights: R.Tensor([M, N], "float32") = R.match_cast( + weights, R.Tensor([M, N], "float32") + ) + weights: R.Tensor([M, N], "float32") = weights * scale + else: + weights: R.Tensor([M, 32], "float32") = B * scale + weights: R.Tensor([M, N], "float32") = R.match_cast( + weights, R.Tensor([M, N], "float32") + ) + weights: R.Tensor([M, N], "float32") = weights * scale + + weights: R.Tensor([M, N], "float32") = weights * scale + + out: R.Tensor([M], "float32") = R.matmul(weights, state) + + return out + + @I.ir_module + class Expected: + @R.function + def main( + state: R.Tensor(["N"], dtype="float32"), + A: R.Tensor(["M", 16], dtype="float32"), + B: R.Tensor(["M", 32], dtype="float32"), + scale: R.Prim("float32"), + ): + N = T.int64() + M = T.int64() + + if N == 16: + # Prior to the R.match_cast, the + weights: R.Tensor([M, 16], "float32") = A * scale + # The scaled weights within the branch may perform + # shape inference knowing that N==16. + weights: R.Tensor([M, 16], "float32") = weights * scale + # The match cast on exiting the if branch restores the + weights = R.match_cast(weights, R.Tensor([M, N], "float32")) + else: + # Prior to the R.match_cast, the + weights: R.Tensor([M, 32], "float32") = B * scale + # Within the else-branch, the R.match_cast implies + # that N==32. While this conflicts with the earlier + # definition, the two occur in separate branches, so + # this is legal. + # The scaled weights within the branch may perform + # shape inference knowing that N==32. + weights: R.Tensor([M, 32], "float32") = weights * scale + weights = R.match_cast(weights, R.Tensor([M, N], "float32")) + + # Outside of the conditional, we no longer have a known + # value for N, so this shape inference must be done using + # a dynamic shape for `N`. + weights: R.Tensor([M, N], "float32") = weights * scale + + # After the conditional branch, we no longer have a known + # value of N, so this shape inference must use the dynamic + # shape. + out: R.Tensor([M], "float32") = R.matmul(weights, state) + + return out + + verify(Before, Expected) + + def test_multiple_outputs(): @I.ir_module class Input: diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index dd0208f5db07..ba5d4d7d1219 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -720,7 +720,7 @@ def reshape( T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)] @R.function - def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"): + def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(ndim=1, dtype="int64"): x_1 = T.int64() gv: R.Shape([3]) = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),)) y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1])) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 64014d1c49be..4f41b662caf2 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -2317,5 +2317,51 @@ def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent-1"]): tvm.ir.assert_structural_equal(inferred_sinfo, expected) +def test_conditional_may_use_symbolic_variables_from_function_scope(): + """Symbolic variables from function scope may be used in branch + + This is a regression test. In earlier implementations, the + branches of `relax::If` were normalized with + `EraseToWellDefinedInScope`, using a fresh variable scope. While + this had the intended behavior of preventing variables defined in + a single branch from being usable outside of the conditional, it + also caused the conditional's branches to treat function-scope + symbolic variables as if they were undefined. + + """ + + @R.function(private=True) + def explicit_sinfo( + A: R.Tensor(["N"], "float32"), + B: R.Tensor(["N"], "float32"), + cond: R.Prim("bool"), + ) -> R.Tensor(["N"], "float32"): + + N = T.int64() + + if cond: + out: R.Tensor([N], "float32") = A + B + else: + out: R.Tensor([N], "float32") = A * B + + return out + + @R.function(private=True) + def inferred_sinfo( + A: R.Tensor(["N"], "float32"), + B: R.Tensor(["N"], "float32"), + cond: R.Prim("bool"), + ): + N = T.int64() + if cond: + out = A + B + else: + out = A * B + + return out + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + if __name__ == "__main__": tvm.testing.main() From 6704175fc7d427bded07e7348c230c58bd9ef75f Mon Sep 17 00:00:00 2001 From: sdalvi-quic <135273488+sdalvi-quic@users.noreply.github.com> Date: Wed, 24 Jul 2024 23:26:47 -0500 Subject: [PATCH 028/202] Pass to eliminate redundant branch and overcompute (#17170) * Implementation to eliminate redundant branch introduced due to operator padding and overcompute, this creates more opportunities to vectorize the code * Fixed lint error in transform.py file * Fixed lint errors in the file using_assume_to_reduce_branches.cc * Fixed lint error in transform.py related to line too long * Fixed Lint error related to space and length of the sentence in using_assume_to_reduce_branches.cc * Fixed lint error : trailing whitespaces in using_assume_to_reduce_breanches.cc * Fixed lint error: clang format issue in cpp files * fixed pylint errors in python files and used clang format to format the cpp files * Ran black format and removed the attr_registry_map.h import as it was running into some other issue because of which build was failing --- include/tvm/tir/transform.h | 8 + python/tvm/tir/transform/transform.py | 13 + .../using_assume_to_reduce_branches.cc | 394 +++++++++++ ...nate_pad_branch_using_buffer_assumption.py | 648 ++++++++++++++++++ 4 files changed, 1063 insertions(+) create mode 100644 src/tir/transforms/using_assume_to_reduce_branches.cc create mode 100644 tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 98edbeaceb26..a8d93bf898c4 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -834,6 +834,14 @@ TVM_DLL Pass InstrumentProfileIntrinsics(); */ TVM_DLL Pass DefaultGPUSchedule(); +/*! + * \brief This pass analyzes primfunc & eliminates branch introdued due to layout specific padding. + * It leverages from the buffer assumptions and use the information to eliminate the branch. + * \note This creates more opportunity to vectorize the code. + * \return The Pass. + */ +TVM_DLL Pass UseAssumeToReduceBranches(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index c2022b918643..d8531401d49d 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -1199,3 +1199,16 @@ def DefaultGPUSchedule(): ret: tvm.transform.Pass """ return _ffi_api.DefaultGPUSchedule() # type: ignore + + +def UseAssumeToReduceBranches(): + """This pass attempts to eliminates layout specific pad branch by overcomputing the values + for padded region. Eliminating the branch will help to vectorize code, + and improve element wise ops performance. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.UseAssumeToReduceBranches() # type: ignore diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc new file mode 100644 index 000000000000..2e45bb0ff8fb --- /dev/null +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file using_assume_to_reduce_branches.cc + * + * \brief Attempt to remove conditional branch statements by introducing + * extra computations that do not impact the final results. Mainly + * oriented for layout specific padding related branches. + * + * \note + * 1. This pass works if the buffer assumption variable is in the branch statement. + * In case, the buffer assumption is not present in the branch statement and + * there are intermediate buffers then, inline the code. + * 2. The assumptions leveraged here should be of the form T.assume(condition_on_indices or + * buffer_equals_to_some_value) + * 3. Some part of the code are reused from the control_flow_graph.cc file which also + * handles eliminating branches in particular scenarios. + * 4. This pass currently works for op_pattern kElemWise and kBroadcast. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../../arith/constraint_extract.h" +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/unwrap_vector_expr.h" +#include "simplify.h" +#include "tvm/ir/expr.h" +namespace tvm { +namespace tir { + +using namespace arith; + +class AssumeChecker : public StmtExprVisitor { + /* This class checks if the primfunc has assume statement. + If yes, then only the FuncAnanlyzerMutator class runs. This is to ensure speedup in the pass.*/ + public: + bool has_assume = false; + + void VisitStmt(const Stmt& stmt) final { + if (has_assume) { + return; + } + StmtVisitor::VisitStmt(stmt); + } + void VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::assume())) { + has_assume = true; + } + } +}; + +class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { + /* This class analyzes the complete primfunc. + It parses the buffer assumptions and eliminates the redundant branch + introduced due to layout specific padding by leveraging from buffer assumptions. + On eliminating the branch there are more opportunities to vectorize the code + and improve performance. + + Example: + ------------- + Prim Func Before : + for (...) + T.assume( assume_condition or A[i] == 0 ) + for (...) + out = T.if_then_else(if_then_else_condition, 0, function(A)) + # here function(A) is some function on Var A + + Prim Func After : + for (...) + T.assume( assume_condition or A[i] == 0 ) + for (...) + out = function(A) # here function(A) is some function on the Var A + -------------- + # High-level implementation details : + 1. The pass parses the assume statement and stores the relevant information. + 2. The pass tries to evaluate the then_clause and else_clause in then_condition_context + and else_condition_context. + It checks if the context of the assume statement (for condition indices and + assume_condition) is same as the context of the if_then_else statement (for condition indices + and if_then_else condition). If context is same and the expression inside if_then_else statement + is a function of the buffer assumption (eg A in above example), + then the pass substitutes the value from the buffer assumption and simplifies the expression. + 3. The pass then checks if then_clause and else_clause evaluate to same value. + If yes, then return the else_clause if we are in the then_condition_context (since then_clause + will be true in this context and if else_clause is also evaluating to true then we can directly + replace it with else_clause), similarly, we return the then_clause if we are in the + else_condition_context. + This class handles all these scenarios.*/ + + public: + using Parent = IRMutatorWithAnalyzer; + explicit ParseAssumeAndOvercompute(Analyzer* analyzer) : Parent(analyzer) {} + + private: + using Parent::VisitExpr_; + using Parent::VisitStmt; + using Parent::VisitStmt_; + + // This struct stores all the relevant data related to asssume statement + struct assume_struct { // Consider the example : T.assume(i < 14 or A[i] == 0) + PrimExpr buffer_context; // The context of the assume statement (the bound on the axis) + PrimExpr buffer_predicate; // The condition inside assume statement (i < 14) excluding + // bufferload expression (A[i] == 0) + tir::BufferLoad buffer_load; // Storing the buffer load Eg: A[i] in A[i] == 0 + PrimExpr buffer_value; // Storing the value for the buffer Eg : 0 in A[i] == 0 + Array buffer_indices; // Storing the indices of the buffer Eg : i + }; + // List of conditions in a scope + std::vector conditions_; + + // Storing all the buffer assumptions data in map + std::map map_buffer_assumption; + tir::Buffer current_bufferstorenode_name; + + struct InternalConstraintContext { + /* This stuct appends the constraint passed to it in the conditions list. + It keeps track of the bounds of the variables along with any conditions on the variables */ + InternalConstraintContext(ParseAssumeAndOvercompute* self, PrimExpr constraint) + : self(self), analyzer_context(self->analyzer_, constraint) { + old_num_constraints = self->conditions_.size(); + + auto side_effect = tir::SideEffect(constraint); + if (side_effect <= tir::CallEffectKind::kPure) { + self->conditions_.push_back(constraint); + } else if (side_effect <= tir::CallEffectKind::kReadState) { + assume = constraint; + } + + new_num_constraints = self->conditions_.size(); + } + + ~InternalConstraintContext() { + ICHECK_EQ(self->conditions_.size(), new_num_constraints) + << "Internal error: Each condition should only be popped once."; + self->conditions_.erase(self->conditions_.begin() + old_num_constraints, + self->conditions_.end()); + } + + ParseAssumeAndOvercompute* self{nullptr}; + With analyzer_context; + size_t old_num_constraints{0}; + size_t new_num_constraints{0}; + Optional assume{NullOpt}; + + // Disable default-generated copy/move assignment and constructors + InternalConstraintContext(const InternalConstraintContext&) = delete; + InternalConstraintContext& operator=(const InternalConstraintContext&) = delete; + InternalConstraintContext(InternalConstraintContext&&) = delete; + InternalConstraintContext& operator=(InternalConstraintContext&&) = delete; + }; + + PrimExpr CurrentScopePredicate() const { + /* This combines all the constraints in a scope */ + PrimExpr predicate = Bool(true); + for (const auto& condition : conditions_) { + predicate = predicate && condition; + } + return predicate; + } + + Stmt VisitStmt_(const ForNode* op) final { + /* Create and delete the scope with bind. + Add the minimum and maximum bound for the variables to the conditions_ list using + InternalConstraintContext */ + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + InternalConstraintContext ctx1(this, op->loop_var >= op->min); + InternalConstraintContext ctx2(this, op->loop_var < op->min + op->extent); + return Parent::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + if (map_buffer_assumption.find(op->buffer) != map_buffer_assumption.end()) { + PrimExpr buf_value; + /* If the cuurent context where the buffer load is present is same as + the context of the buffer assumption then, return the buffer value present in the assumption. + This will eventually replace the bufferload value in the complete expresison */ + + auto buffer_assumption = map_buffer_assumption[op->buffer]; + PrimExpr current_predicate_and_context = CurrentScopePredicate(); + PrimExpr buffer_predicate_and_context = + buffer_assumption.buffer_context && buffer_assumption.buffer_predicate; + bool current_context_and_buffer_constraint_is_same = StructuralEqual()( + current_predicate_and_context, buffer_predicate_and_context, /*map_free_vars=*/true); + + if (current_context_and_buffer_constraint_is_same) { + buf_value = buffer_assumption.buffer_value; + return buf_value; + } + } + return GetRef(op); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(Parent::VisitStmt_(op)); + + // Eliminate the builtin if_then_else statement + if (auto* call = op->value.as()) { + if (call->op.same_as(builtin::if_then_else())) { + PrimExpr cond = call->args[0]; + PrimExpr then_clause = call->args[1]; + PrimExpr else_clause = call->args[2]; + + PrimExpr then_clause_in_then_context; + PrimExpr else_clause_in_then_context; + PrimExpr then_clause_in_else_context; + PrimExpr else_clause_in_else_context; + { + // Simplifying expressions in " then context " + InternalConstraintContext then_ctx(this, cond); + // This will call the current class's appropriate VisitStmt function + then_clause_in_then_context = (*this)(then_clause); + then_clause_in_then_context = analyzer_->Simplify(then_clause_in_then_context); + + else_clause_in_then_context = (*this)(else_clause); + else_clause_in_then_context = analyzer_->Simplify(else_clause_in_then_context); + } + { + // Simplifying expressions in " else context " + InternalConstraintContext else_ctx(this, !cond); + // This will call the current class's appropriate VisitStmt function + then_clause_in_else_context = (*this)(then_clause); + then_clause_in_else_context = analyzer_->Simplify(then_clause_in_else_context); + + else_clause_in_else_context = (*this)(else_clause); + else_clause_in_else_context = analyzer_->Simplify(else_clause_in_else_context); + } + + auto n = this->CopyOnWrite(op); + if (StructuralEqual()(then_clause_in_then_context, else_clause_in_then_context)) { + n->value = analyzer_->Simplify(else_clause); + return Stmt(n); + } else if (StructuralEqual()(then_clause_in_else_context, else_clause_in_else_context)) { + n->value = analyzer_->Simplify(then_clause); + return Stmt(n); + } else { + return Parent::VisitStmt_(op); + } + } + } + return Parent::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::assume())) { + Assume(op->args[0]); + } + return Parent::VisitExpr_(op); + } + + void Assume(PrimExpr assumption) { + for (const auto& expr : arith::ExtractConstraints(assumption, false)) { + AssumeConstraintComponent(expr); + } + } + + void AssumeConstraintComponent(PrimExpr assumption) { + PrimExpr additional_predicate = Bool(true); + assume_struct buf_data; + + std::vector buffer_exprs; + for (const auto& expr : arith::ExtractComponents(assumption)) { + auto side_effect = tir::SideEffect(expr); + if (side_effect <= tir::CallEffectKind::kPure) { + // Pulling out portions of the assumption that do not depend + // on a buffer value allows the following two forms to be + // treated identically. + // + // Option 1: if i < 3: T.assume(buf[i] == value) + // Option 2: T.assume(i>=3 or buf[i] == value) + additional_predicate = additional_predicate && logical_not(expr); + } else if (side_effect == tir::CallEffectKind::kReadState) { + buffer_exprs.push_back(expr); + } else { + LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr + << " with side-effect \'" << side_effect << "\'"; + } + } + + additional_predicate = analyzer_->Simplify(std::move(additional_predicate)); + CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; + + auto* as_equal_node = buffer_exprs[0].as(); + CHECK(as_equal_node) << "T.assume buffer constraint must be of the form 'buffer[indices] == " + "value', but received " + << assumption; + if (!as_equal_node) { + // This assumption is an inequality on a data-dependent + // conditional. Not an error for this to occur, but also not + // something that is currently supported. + return; + } + + // Parse the statement and store the desired values + // Ex: A[i]==0, load = A[i], value = 0 + tir::BufferLoad load; + PrimExpr value; + if (auto opt = as_equal_node->a.as()) { + load = opt.value(); + value = as_equal_node->b; + } else if (auto opt = as_equal_node->b.as()) { + load = opt.value(); + value = as_equal_node->a; + } else { + LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; + } + + // Populating the assume statement predicate, buffer, value + // and the context of the assume statement + buf_data.buffer_context = CurrentScopePredicate(); + buf_data.buffer_predicate = additional_predicate; + buf_data.buffer_load = load; + buf_data.buffer_value = value; + buf_data.buffer_indices = load->indices; + for (size_t i = 0; i < load->indices.size(); i++) { + buf_data.buffer_indices.push_back(analyzer_->Simplify(load->indices[i])); + } + map_buffer_assumption[buf_data.buffer_load->buffer] = buf_data; + + auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; + CHECK(!has_side_effect) << "Buffer value in constraint must be pure expression, but was " + << value; + if (has_side_effect) { + return; + } + } +}; + +namespace transform { + +Pass UseAssumeToReduceBranches() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + arith::Analyzer analyzer; + + // The pass runs & eliminates pad branch with overcompute only if, + // the primfunc has op_pattern defined and is an elementwise op. + // AnnotateTIROpPattern pass will set op_pattern in op attributes of the primfunc. + if (n->attrs.GetAttr("op_pattern").defined()) { + Optional opt_pattern = f->GetAttr("op_pattern"); + if (opt_pattern.defined()) { + relay::OpPatternKind pattern; + pattern = static_cast(Downcast(opt_pattern)->value); + + if (pattern == relay::OpPatternKind::kElemWise || + pattern == relay::OpPatternKind::kBroadcast) { + // If the primfunc contains assume statement then, run the mutator pass. + AssumeChecker assume_checker; + assume_checker(std::move(n->body)); + + if (assume_checker.has_assume) { + // Leverage from assume and eliminate the branch + ParseAssumeAndOvercompute func_analyzer_mutator(&analyzer); + n->body = func_analyzer_mutator(std::move(n->body)); + } + } + } + } + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.UseAssumeToReduceBranches", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches") + .set_body_typed(UseAssumeToReduceBranches); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py new file mode 100644 index 000000000000..b8ff2b6c79b2 --- /dev/null +++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py @@ -0,0 +1,648 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, unused-variable + +# The test attempts to eliminate redundant pad branch and overcompute the value for elementwise ops. +# This helps to expose more opportunities to vectorize the code. + +import tvm +import tvm.testing + +import tvm.script +from tvm.script import tir as T, relax as R + + +@tvm.script.ir_module +class AddBefore: + @T.prim_func(private=True) + def add( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "add", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + AddBefore.add, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class AddExpected: + @T.prim_func(private=True) + def add( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "add", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + AddExpected.add, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class SubBefore: + @T.prim_func(private=True) + def sub( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "sub", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + - b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + SubBefore.sub, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class SubExpected: + @T.prim_func(private=True) + def sub( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "sub", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + - b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + SubExpected.sub, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class MulBefore: + @T.prim_func(private=True) + def mul( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "mul", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + * b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + MulBefore.mul, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class MulExpected: + @T.prim_func(private=True) + def mul( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "mul", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + * b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + MulExpected.mul, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +def test_add_primfunc_overcompute(): + add_after = tvm.tir.transform.UseAssumeToReduceBranches()(AddBefore) + tvm.ir.structural_equal(add_after["add"], AddExpected["add"], map_free_vars=True) + + +def test_sub_primfunc_overcompute(): + sub_after = tvm.tir.transform.UseAssumeToReduceBranches()(SubBefore) + tvm.ir.structural_equal(sub_after["sub"], SubExpected["sub"], map_free_vars=True) + + +def test_mul_primfunc_overcompute(): + mul_after = tvm.tir.transform.UseAssumeToReduceBranches()(MulBefore) + tvm.ir.structural_equal(mul_after["mul"], MulExpected["mul"], map_free_vars=True) + + +if __name__ == "__main__": + tvm.testing.main() From 08d75197e1033d64cba5da0407a7489759c5dba5 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 25 Jul 2024 16:44:55 +0300 Subject: [PATCH 029/202] [Cython][FFI] Fix crash when call del operator for handle (#17190) * [Cython][FFI] Fix crash when call del operator for handle In case of cython when we create a set function for property then the following code will be generated: ``` static int __pyx_setprop_4test_9TestClass_handle(PyObject *o, PyObject *v, CYTHON_UNUSED void *x) { if (v) { return __pyx_pw_4test_9TestClass_6handle_3__set__(o, v); } else { PyErr_SetString(PyExc_NotImplementedError, "__del__"); return -1; } } ``` And when we call operator `del` for this handler, then the memory will be released and operator `__set__` will be called for NULL object. In this case an exception that operator `__del__` is not implemented will be generated. To avoid this problem we need to declare `__del__` function for each property where we define operator `__set__`. * Apply comments * Set dref.handle to None instead of using __del__ functions --- python/tvm/runtime/disco/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index 38c4f2a2354c..89ef549df3ee 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -92,7 +92,7 @@ class DModule(DRef): def __init__(self, dref: DRef, session: "Session") -> None: self.handle = dref.handle - del dref.handle + dref.handle = None self.session = session def __getitem__(self, name: str) -> DPackedFunc: From 1b6c00d7560afded9b5380abfd3f182461b9448d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 25 Jul 2024 21:11:33 -0700 Subject: [PATCH 030/202] [Disco] Implement SocketSession (#17182) * [Disco] Implement SocketSession Implements SocketSession that connects multiple local worker processes/threads over multiple distributed nodes via TCP socket. * doc * lint * resolve conflcit * lint * add local worker id * lint * lint * disable for hexagon * remove from header --- CMakeLists.txt | 6 + include/tvm/runtime/disco/disco_worker.h | 4 + include/tvm/runtime/disco/session.h | 1 + .../tvm/exec/disco_remote_socket_session.py | 33 ++ python/tvm/runtime/disco/__init__.py | 1 + python/tvm/runtime/disco/session.py | 23 ++ src/runtime/disco/bcast_session.h | 20 ++ src/runtime/disco/disco_worker.cc | 4 +- .../disco/distributed/socket_session.cc | 332 ++++++++++++++++++ src/runtime/disco/message_queue.h | 133 +++++++ src/runtime/disco/nccl/nccl.cc | 4 +- src/runtime/disco/process_session.cc | 128 ++----- src/runtime/disco/threaded_session.cc | 4 + src/support/socket.h | 6 +- tests/python/disco/test_session.py | 87 ++++- 15 files changed, 676 insertions(+), 110 deletions(-) create mode 100644 python/tvm/exec/disco_remote_socket_session.py create mode 100644 src/runtime/disco/distributed/socket_session.cc create mode 100644 src/runtime/disco/message_queue.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 7575d6c2b4d6..7fba5355f077 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -387,6 +387,12 @@ if(BUILD_FOR_HEXAGON) add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0) endif() +# distributed disco runtime are disabled for hexagon +if (NOT BUILD_FOR_HEXAGON) + tvm_file_glob(GLOB RUNTIME_DISCO_DISTRIBUTED_SRCS src/runtime/disco/distributed/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_DISCO_DISTRIBUTED_SRCS}) +endif() + # Package runtime rules if(NOT USE_RTTI) add_definitions(-DDMLC_ENABLE_RTTI=0) diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h index 13f94802c886..c9c85b7dbfed 100644 --- a/include/tvm/runtime/disco/disco_worker.h +++ b/include/tvm/runtime/disco/disco_worker.h @@ -52,6 +52,7 @@ class DiscoWorker { explicit DiscoWorker(int worker_id, int num_workers, int num_groups, WorkerZeroData* worker_zero_data, DiscoChannel* channel) : worker_id(worker_id), + local_worker_id(worker_id), num_workers(num_workers), num_groups(num_groups), default_device(Device{DLDeviceType::kDLCPU, 0}), @@ -68,6 +69,9 @@ class DiscoWorker { /*! \brief The id of the worker.*/ int worker_id; + /*! \brief The local id of the worker. This can be different from worker_id if the session is + * consisted with multiple sub-sessions. */ + int local_worker_id; /*! \brief Total number of workers */ int num_workers; /*! \brief Total number of workers */ diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 97fa79096d63..9c34f8a2af9e 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -281,6 +281,7 @@ class Session : public ObjectRef { */ TVM_DLL static Session ProcessSession(int num_workers, int num_groups, String process_pool_creator, String entrypoint); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); }; diff --git a/python/tvm/exec/disco_remote_socket_session.py b/python/tvm/exec/disco_remote_socket_session.py new file mode 100644 index 000000000000..3111ce30ac4b --- /dev/null +++ b/python/tvm/exec/disco_remote_socket_session.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Launch disco session in the remote node and connect to the server.""" +import sys +import tvm +from . import disco_worker as _ # pylint: disable=unused-import + + +if __name__ == "__main__": + if len(sys.argv) != 4: + print("Usage: ") + sys.exit(1) + + server_host = sys.argv[1] + server_port = int(sys.argv[2]) + num_workers = int(sys.argv[3]) + func = tvm.get_global_func("runtime.disco.RemoteSocketSession") + func(server_host, server_port, num_workers) diff --git a/python/tvm/runtime/disco/__init__.py b/python/tvm/runtime/disco/__init__.py index 856e69bc3598..2ba524cade66 100644 --- a/python/tvm/runtime/disco/__init__.py +++ b/python/tvm/runtime/disco/__init__.py @@ -22,4 +22,5 @@ ProcessSession, Session, ThreadedSession, + SocketSession, ) diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index 89ef549df3ee..1749942a9ca0 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -574,6 +574,29 @@ def _configure_structlog(self) -> None: func(config, os.getpid()) +@register_func("runtime.disco.create_socket_session_local_workers") +def _create_socket_session_local_workers(num_workers) -> Session: + """Create the local session for each distributed node over socket session.""" + return ProcessSession(num_workers) + + +@register_object("runtime.disco.SocketSession") +class SocketSession(Session): + """A Disco session backed by socket-based multi-node communication.""" + + def __init__( + self, num_nodes: int, num_workers_per_node: int, num_groups: int, host: str, port: int + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.SocketSession, # type: ignore # pylint: disable=no-member + num_nodes, + num_workers_per_node, + num_groups, + host, + port, + ) + + @register_func("runtime.disco._configure_structlog") def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None: """Configure structlog for all disco workers diff --git a/src/runtime/disco/bcast_session.h b/src/runtime/disco/bcast_session.h index 1a4df634b738..0e4ca614d418 100644 --- a/src/runtime/disco/bcast_session.h +++ b/src/runtime/disco/bcast_session.h @@ -65,6 +65,16 @@ class BcastSessionObj : public SessionObj { * \param TVMArgs The input arguments in TVM's PackedFunc calling convention */ virtual void BroadcastPacked(const TVMArgs& args) = 0; + + /*! + * \brief Send a packed sequence to a worker. This function is usually called by the controler to + * communicate with worker-0, because the worker-0 is assumed to be always collocated with the + * controler. Sending to other workers may not be supported. + * \param worker_id The worker id to send the packed sequence to. + * \param args The packed sequence to send. + */ + virtual void SendPacked(int worker_id, const TVMArgs& args) = 0; + /*! * \brief Receive a packed sequence from a worker. This function is usually called by the * controler to communicate with worker-0, because the worker-0 is assumed to be always @@ -83,6 +93,16 @@ class BcastSessionObj : public SessionObj { struct Internal; friend struct Internal; + friend class SocketSessionObj; + friend class RemoteSocketSession; +}; + +/*! + * \brief Managed reference to BcastSessionObj. + */ +class BcastSession : public Session { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BcastSession, Session, BcastSessionObj); }; } // namespace runtime diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc index 5e6f401054ea..4007b104f252 100644 --- a/src/runtime/disco/disco_worker.cc +++ b/src/runtime/disco/disco_worker.cc @@ -120,7 +120,7 @@ struct DiscoWorker::Impl { } static void CopyFromWorker0(DiscoWorker* self, int reg_id) { - if (self->worker_zero_data != nullptr) { + if (self->worker_id == 0) { NDArray tgt = GetNDArrayFromHost(self); NDArray src = GetReg(self, reg_id); tgt.CopyFrom(src); @@ -128,7 +128,7 @@ struct DiscoWorker::Impl { } static void CopyToWorker0(DiscoWorker* self, int reg_id) { - if (self->worker_zero_data != nullptr) { + if (self->worker_id == 0) { NDArray src = GetNDArrayFromHost(self); NDArray tgt = GetReg(self, reg_id); tgt.CopyFrom(src); diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc new file mode 100644 index 000000000000..07196be3056b --- /dev/null +++ b/src/runtime/disco/distributed/socket_session.cc @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include + +#include "../../../support/socket.h" +#include "../bcast_session.h" +#include "../message_queue.h" + +namespace tvm { +namespace runtime { + +using namespace tvm::support; + +enum class DiscoSocketAction { + kShutdown = static_cast(DiscoAction::kShutDown), + kSend, + kReceive, +}; + +class DiscoSocketChannel : public DiscoChannel { + public: + explicit DiscoSocketChannel(const TCPSocket& socket) + : socket_(socket), message_queue_(&socket_) {} + + DiscoSocketChannel(DiscoSocketChannel&& other) = delete; + DiscoSocketChannel(const DiscoSocketChannel& other) = delete; + void Send(const TVMArgs& args) { message_queue_.Send(args); } + TVMArgs Recv() { return message_queue_.Recv(); } + void Reply(const TVMArgs& args) { message_queue_.Send(args); } + TVMArgs RecvReply() { return message_queue_.Recv(); } + + private: + TCPSocket socket_; + DiscoStreamMessageQueue message_queue_; +}; + +class SocketSessionObj : public BcastSessionObj { + public: + explicit SocketSessionObj(int num_nodes, int num_workers_per_node, int num_groups, + const String& host, int port) + : num_nodes_(num_nodes), num_workers_per_node_(num_workers_per_node) { + const PackedFunc* f_create_local_session = + Registry::Get("runtime.disco.create_socket_session_local_workers"); + ICHECK(f_create_local_session != nullptr) + << "Cannot find function runtime.disco.create_socket_session_local_workers"; + local_session_ = ((*f_create_local_session)(num_workers_per_node)).AsObjectRef(); + DRef f_init_workers = + local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers"); + local_session_->CallPacked(f_init_workers, num_nodes_, /*node_id=*/0, num_groups, + num_workers_per_node_); + + Socket::Startup(); + socket_.Create(); + socket_.SetKeepAlive(true); + socket_.Bind(SockAddr(host.c_str(), port)); + socket_.Listen(); + LOG(INFO) << "SocketSession controller listening on " << host << ":" << port; + + TVMValue values[4]; + int type_codes[4]; + TVMArgsSetter setter(values, type_codes); + setter(0, num_nodes); + setter(1, num_workers_per_node); + setter(2, num_groups); + + for (int i = 0; i + 1 < num_nodes; ++i) { + SockAddr addr; + remote_sockets_.push_back(socket_.Accept(&addr)); + remote_channels_.emplace_back(std::make_unique(remote_sockets_.back())); + setter(3, i + 1); + // Send metadata to each remote node: + // - num_nodes + // - num_workers_per_node + // - num_groups + // - node_id + remote_channels_.back()->Send(TVMArgs(values, type_codes, 4)); + LOG(INFO) << "Remote node " << addr.AsString() << " connected"; + } + } + + int64_t GetNumWorkers() final { return num_nodes_ * num_workers_per_node_; } + + TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) final { + int node_id = worker_id / num_workers_per_node_; + if (node_id == 0) { + return local_session_->DebugGetFromRemote(reg_id, worker_id); + } else { + std::vector values(5); + std::vector type_codes(5); + PackArgs(values.data(), type_codes.data(), static_cast(DiscoSocketAction::kSend), + worker_id, static_cast(DiscoAction::kDebugGetFromRemote), reg_id, worker_id); + + remote_channels_[node_id - 1]->Send(TVMArgs(values.data(), type_codes.data(), values.size())); + TVMArgs args = this->RecvReplyPacked(worker_id); + ICHECK_EQ(args.size(), 2); + ICHECK(static_cast(args[0].operator int()) == DiscoAction::kDebugGetFromRemote); + TVMRetValue result; + result = args[1]; + return result; + } + } + + void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) final { + int node_id = worker_id / num_workers_per_node_; + if (node_id == 0) { + local_session_->DebugSetRegister(reg_id, value, worker_id); + } else { + ObjectRef wrapped{nullptr}; + if (value.type_code() == kTVMNDArrayHandle || value.type_code() == kTVMObjectHandle) { + wrapped = DiscoDebugObject::Wrap(value); + TVMValue tvm_value; + int type_code = kTVMObjectHandle; + tvm_value.v_handle = const_cast(wrapped.get()); + value = TVMArgValue(tvm_value, type_code); + } + { + TVMValue values[6]; + int type_codes[6]; + PackArgs(values, type_codes, static_cast(DiscoSocketAction::kSend), worker_id, + static_cast(DiscoAction::kDebugSetRegister), reg_id, worker_id, value); + remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 6)); + } + TVMRetValue result; + TVMArgs args = this->RecvReplyPacked(worker_id); + ICHECK_EQ(args.size(), 1); + ICHECK(static_cast(args[0].operator int()) == DiscoAction::kDebugSetRegister); + } + } + + void BroadcastPacked(const TVMArgs& args) final { + local_session_->BroadcastPacked(args); + std::vector values(args.size() + 2); + std::vector type_codes(args.size() + 2); + PackArgs(values.data(), type_codes.data(), static_cast(DiscoSocketAction::kSend), -1); + std::copy(args.values, args.values + args.size(), values.begin() + 2); + std::copy(args.type_codes, args.type_codes + args.size(), type_codes.begin() + 2); + for (auto& channel : remote_channels_) { + channel->Send(TVMArgs(values.data(), type_codes.data(), values.size())); + } + } + + void SendPacked(int worker_id, const TVMArgs& args) final { + int node_id = worker_id / num_workers_per_node_; + if (node_id == 0) { + local_session_->SendPacked(worker_id, args); + return; + } + std::vector values(args.size() + 2); + std::vector type_codes(args.size() + 2); + PackArgs(values.data(), type_codes.data(), static_cast(DiscoSocketAction::kSend), + worker_id); + std::copy(args.values, args.values + args.size(), values.begin() + 2); + std::copy(args.type_codes, args.type_codes + args.size(), type_codes.begin() + 2); + remote_channels_[node_id - 1]->Send(TVMArgs(values.data(), type_codes.data(), values.size())); + } + + TVMArgs RecvReplyPacked(int worker_id) final { + int node_id = worker_id / num_workers_per_node_; + if (node_id == 0) { + return local_session_->RecvReplyPacked(worker_id); + } + TVMValue values[2]; + int type_codes[2]; + PackArgs(values, type_codes, static_cast(DiscoSocketAction::kReceive), worker_id); + remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 2)); + return remote_channels_[node_id - 1]->Recv(); + } + + void AppendHostNDArray(const NDArray& host_array) final { + local_session_->AppendHostNDArray(host_array); + } + + void Shutdown() final { + // local session will be implicitly shutdown by its destructor + TVMValue values[2]; + int type_codes[2]; + PackArgs(values, type_codes, static_cast(DiscoSocketAction::kShutdown), -1); + for (auto& channel : remote_channels_) { + channel->Send(TVMArgs(values, type_codes, 2)); + } + for (auto& socket : remote_sockets_) { + socket.Close(); + } + remote_sockets_.clear(); + remote_channels_.clear(); + if (!socket_.IsClosed()) { + socket_.Close(); + } + Socket::Finalize(); + } + + ~SocketSessionObj() { Shutdown(); } + + static constexpr const char* _type_key = "runtime.disco.SocketSession"; + TVM_DECLARE_FINAL_OBJECT_INFO(SocketSessionObj, BcastSessionObj); + int num_nodes_; + int num_workers_per_node_; + TCPSocket socket_; + std::vector remote_sockets_; + std::vector> remote_channels_; + BcastSession local_session_{nullptr}; +}; + +TVM_REGISTER_OBJECT_TYPE(SocketSessionObj); + +class RemoteSocketSession { + public: + explicit RemoteSocketSession(const String& server_host, int server_port, int num_local_workers) { + socket_.Create(); + socket_.SetKeepAlive(true); + SockAddr server_addr{server_host.c_str(), server_port}; + Socket::Startup(); + if (!socket_.Connect(server_addr)) { + LOG(FATAL) << "Failed to connect to server " << server_addr.AsString() + << ", errno = " << Socket::GetLastErrorCode(); + } + channel_ = std::make_unique(socket_); + TVMArgs metadata = channel_->Recv(); + ICHECK_EQ(metadata.size(), 4); + num_nodes_ = metadata[0].operator int(); + num_workers_per_node_ = metadata[1].operator int(); + num_groups_ = metadata[2].operator int(); + node_id_ = metadata[3].operator int(); + CHECK_GE(num_local_workers, num_workers_per_node_); + InitLocalSession(); + } + + void MainLoop() { + while (true) { + TVMArgs args = channel_->Recv(); + DiscoSocketAction action = static_cast(args[0].operator int()); + int worker_id = args[1].operator int(); + int local_worker_id = worker_id - node_id_ * num_workers_per_node_; + switch (action) { + case DiscoSocketAction::kSend: { + args = TVMArgs(args.values + 2, args.type_codes + 2, args.size() - 2); + if (worker_id == -1) { + local_session_->BroadcastPacked(args); + } else { + local_session_->SendPacked(local_worker_id, args); + } + break; + } + case DiscoSocketAction::kReceive: { + args = local_session_->RecvReplyPacked(local_worker_id); + channel_->Reply(args); + break; + } + case DiscoSocketAction::kShutdown: { + local_session_->Shutdown(); + LOG(INFO) << "Connection closed by remote controller."; + return; + } + default: + LOG(FATAL) << "Invalid action " << static_cast(action); + } + } + } + + ~RemoteSocketSession() { + socket_.Close(); + Socket::Finalize(); + } + + private: + void InitLocalSession() { + const PackedFunc* f_create_local_session = + Registry::Get("runtime.disco.create_socket_session_local_workers"); + local_session_ = ((*f_create_local_session)(num_workers_per_node_)).AsObjectRef(); + + DRef f_init_workers = + local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers"); + local_session_->CallPacked(f_init_workers, num_nodes_, node_id_, num_groups_, + num_workers_per_node_); + } + + TCPSocket socket_; + BcastSession local_session_{nullptr}; + std::unique_ptr channel_; + int num_nodes_{-1}; + int node_id_{-1}; + int num_groups_{-1}; + int num_workers_per_node_{-1}; +}; + +void RemoteSocketSessionEntryPoint(const String& server_host, int server_port, + int num_local_workers) { + RemoteSocketSession proxy(server_host, server_port, num_local_workers); + proxy.MainLoop(); +} + +TVM_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession") + .set_body_typed(RemoteSocketSessionEntryPoint); + +Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, const String& host, + int port) { + auto n = make_object(num_nodes, num_workers_per_node, num_groups, host, port); + return Session(n); +} + +TVM_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession); + +TVM_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers") + .set_body_typed([](int num_nodes, int node_id, int num_groups, int num_workers_per_node) { + LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, " + << num_workers_per_node << " workers per node, and " << num_groups << " groups."; + DiscoWorker* worker = DiscoWorker::ThreadLocal(); + worker->num_groups = num_groups; + worker->worker_id = worker->worker_id + node_id * num_workers_per_node; + worker->num_workers = num_nodes * num_workers_per_node; + }); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/disco/message_queue.h b/src/runtime/disco/message_queue.h new file mode 100644 index 000000000000..3b78c3e5c187 --- /dev/null +++ b/src/runtime/disco/message_queue.h @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_ +#define TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_ + +#include + +#include + +#include "./protocol.h" + +namespace tvm { +namespace runtime { + +class DiscoStreamMessageQueue : private dmlc::Stream, + private DiscoProtocol { + public: + explicit DiscoStreamMessageQueue(Stream* stream) : stream_(stream) {} + + ~DiscoStreamMessageQueue() = default; + + void Send(const TVMArgs& args) { + RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this); + CommitSendAndNotifyEnqueue(); + } + + TVMArgs Recv() { + bool is_implicit_shutdown = DequeueNextPacket(); + TVMValue* values = nullptr; + int* type_codes = nullptr; + int num_args = 0; + + if (is_implicit_shutdown) { + num_args = 2; + values = ArenaAlloc(num_args); + type_codes = ArenaAlloc(num_args); + TVMArgsSetter setter(values, type_codes); + setter(0, static_cast(DiscoAction::kShutDown)); + setter(1, 0); + } else { + RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this); + } + return TVMArgs(values, type_codes, num_args); + } + + protected: + void CommitSendAndNotifyEnqueue() { + stream_->Write(write_buffer_.data(), write_buffer_.size()); + write_buffer_.clear(); + } + + /* \brief Read next packet and reset unpacker + * + * Read the next packet into `read_buffer_`, releasing all arena + * allocations performed by the unpacker and resetting the unpacker + * to its initial state. + * + * \return A boolean value. If true, this packet should be treated + * equivalently to a `DiscoAction::kShutdown` event. If false, + * this packet should be unpacked. + */ + bool DequeueNextPacket() { + uint64_t packet_nbytes = 0; + int read_size = stream_->Read(&packet_nbytes, sizeof(packet_nbytes)); + if (read_size == 0) { + // Special case, connection dropped between packets. Treat as a + // request to shutdown. + return true; + } + + ICHECK_EQ(read_size, sizeof(packet_nbytes)) + << "Stream closed without proper shutdown. Please make sure to explicitly call " + "`Session::Shutdown`"; + read_buffer_.resize(packet_nbytes); + read_size = stream_->Read(read_buffer_.data(), packet_nbytes); + ICHECK_EQ(read_size, packet_nbytes) + << "Stream closed without proper shutdown. Please make sure to explicitly call " + "`Session::Shutdown`"; + read_offset_ = 0; + this->RecycleAll(); + RPCCode code = RPCCode::kReturn; + this->Read(&code); + return false; + } + + size_t Read(void* data, size_t size) final { + std::memcpy(data, read_buffer_.data() + read_offset_, size); + read_offset_ += size; + ICHECK_LE(read_offset_, read_buffer_.size()); + return size; + } + + size_t Write(const void* data, size_t size) final { + size_t cur_size = write_buffer_.size(); + write_buffer_.resize(cur_size + size); + std::memcpy(write_buffer_.data() + cur_size, data, size); + return size; + } + + using dmlc::Stream::Read; + using dmlc::Stream::ReadArray; + using dmlc::Stream::Write; + using dmlc::Stream::WriteArray; + friend struct RPCReference; + friend struct DiscoProtocol; + + // The read/write buffer will only be accessed by the producer thread. + std::string write_buffer_; + std::string read_buffer_; + size_t read_offset_ = 0; + dmlc::Stream* stream_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_ diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 35e8fd06b309..d35fc911c692 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -86,7 +86,8 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { << "and has not been destructed"; // Step up local context of NCCL - int device_id = device_ids[worker->worker_id]; + int group_size = worker->num_workers / worker->num_groups; + int device_id = device_ids[worker->local_worker_id]; SetDevice(device_id); #if TVM_NCCL_RCCL_SWITCH == 0 StreamCreate(&ctx->default_stream); @@ -99,7 +100,6 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { // Initialize the communicator ncclUniqueId id; std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES); - int group_size = worker->num_workers / worker->num_groups; NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, worker->worker_id)); NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size, worker->worker_id % group_size, &ctx->group_comm, NULL)); diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index 7c8d0796dd81..161c3f6e0408 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -31,114 +31,19 @@ #include "../minrpc/rpc_reference.h" #include "./bcast_session.h" #include "./disco_worker_thread.h" +#include "./message_queue.h" #include "./protocol.h" namespace tvm { namespace runtime { -class DiscoPipeMessageQueue : private dmlc::Stream, private DiscoProtocol { - public: - explicit DiscoPipeMessageQueue(int64_t handle) : pipe_(handle) {} - - ~DiscoPipeMessageQueue() = default; - - void Send(const TVMArgs& args) { - RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this); - CommitSendAndNotifyEnqueue(); - } - - TVMArgs Recv() { - bool is_implicit_shutdown = DequeueNextPacket(); - TVMValue* values = nullptr; - int* type_codes = nullptr; - int num_args = 0; - - if (is_implicit_shutdown) { - num_args = 2; - values = ArenaAlloc(num_args); - type_codes = ArenaAlloc(num_args); - TVMArgsSetter setter(values, type_codes); - setter(0, static_cast(DiscoAction::kShutDown)); - setter(1, 0); - } else { - RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this); - } - return TVMArgs(values, type_codes, num_args); - } - - protected: - void CommitSendAndNotifyEnqueue() { - pipe_.Write(write_buffer_.data(), write_buffer_.size()); - write_buffer_.clear(); - } - - /* \brief Read next packet and reset unpacker - * - * Read the next packet into `read_buffer_`, releasing all arena - * allocations performed by the unpacker and resetting the unpacker - * to its initial state. - * - * \return A boolean value. If true, this packet should be treated - * equivalently to a `DiscoAction::kShutdown` event. If false, - * this packet should be unpacked. - */ - bool DequeueNextPacket() { - uint64_t packet_nbytes = 0; - int read_size = pipe_.Read(&packet_nbytes, sizeof(packet_nbytes)); - if (read_size == 0) { - // Special case, connection dropped between packets. Treat as a - // request to shutdown. - return true; - } - - ICHECK_EQ(read_size, sizeof(packet_nbytes)) - << "Pipe closed without proper shutdown. Please make sure to explicitly call " - "`Session::Shutdown`"; - read_buffer_.resize(packet_nbytes); - read_size = pipe_.Read(read_buffer_.data(), packet_nbytes); - ICHECK_EQ(read_size, packet_nbytes) - << "Pipe closed without proper shutdown. Please make sure to explicitly call " - "`Session::Shutdown`"; - read_offset_ = 0; - this->RecycleAll(); - RPCCode code = RPCCode::kReturn; - this->Read(&code); - return false; - } - - size_t Read(void* data, size_t size) final { - std::memcpy(data, read_buffer_.data() + read_offset_, size); - read_offset_ += size; - ICHECK_LE(read_offset_, read_buffer_.size()); - return size; - } - - size_t Write(const void* data, size_t size) final { - size_t cur_size = write_buffer_.size(); - write_buffer_.resize(cur_size + size); - std::memcpy(write_buffer_.data() + cur_size, data, size); - return size; - } - - using dmlc::Stream::Read; - using dmlc::Stream::ReadArray; - using dmlc::Stream::Write; - using dmlc::Stream::WriteArray; - friend struct RPCReference; - friend struct DiscoProtocol; - - // The read/write buffer will only be accessed by the producer thread. - std::string write_buffer_; - std::string read_buffer_; - size_t read_offset_ = 0; - support::Pipe pipe_; -}; - class DiscoProcessChannel final : public DiscoChannel { public: DiscoProcessChannel(int64_t controler_to_worker_fd, int64_t worker_to_controler_fd) - : controler_to_worker_(controler_to_worker_fd), - worker_to_controler_(worker_to_controler_fd) {} + : controller_to_worker_pipe_(controler_to_worker_fd), + worker_to_controller_pipe_(worker_to_controler_fd), + controler_to_worker_(&controller_to_worker_pipe_), + worker_to_controler_(&worker_to_controller_pipe_) {} DiscoProcessChannel(DiscoProcessChannel&& other) = delete; DiscoProcessChannel(const DiscoProcessChannel& other) = delete; @@ -148,8 +53,10 @@ class DiscoProcessChannel final : public DiscoChannel { void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); } TVMArgs RecvReply() { return worker_to_controler_.Recv(); } - DiscoPipeMessageQueue controler_to_worker_; - DiscoPipeMessageQueue worker_to_controler_; + support::Pipe controller_to_worker_pipe_; + support::Pipe worker_to_controller_pipe_; + DiscoStreamMessageQueue controler_to_worker_; + DiscoStreamMessageQueue worker_to_controler_; }; class ProcessSessionObj final : public BcastSessionObj { @@ -226,7 +133,7 @@ class ProcessSessionObj final : public BcastSessionObj { int type_codes[4]; PackArgs(values, type_codes, static_cast(DiscoAction::kDebugSetRegister), reg_id, worker_id, value); - workers_[worker_id - 1]->Send(TVMArgs(values, type_codes, 4)); + SendPacked(worker_id, TVMArgs(values, type_codes, 4)); } TVMRetValue result; TVMArgs args = this->RecvReplyPacked(worker_id); @@ -241,6 +148,14 @@ class ProcessSessionObj final : public BcastSessionObj { } } + void SendPacked(int worker_id, const TVMArgs& args) final { + if (worker_id == 0) { + worker_0_->channel->Send(args); + } else { + workers_.at(worker_id - 1)->Send(args); + } + } + TVMArgs RecvReplyPacked(int worker_id) final { if (worker_id == 0) { return worker_0_->channel->RecvReply(); @@ -248,6 +163,13 @@ class ProcessSessionObj final : public BcastSessionObj { return this->workers_.at(worker_id - 1)->RecvReply(); } + DiscoChannel* GetWorkerChannel(int worker_id) { + if (worker_id == 0) { + return worker_0_->channel.get(); + } + return workers_.at(worker_id - 1).get(); + } + PackedFunc process_pool_; std::unique_ptr worker_0_; std::vector> workers_; diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index cc9a311a6b3f..bf6b6107e122 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -173,6 +173,10 @@ class ThreadedSessionObj final : public BcastSessionObj { } } + void SendPacked(int worker_id, const TVMArgs& args) final { + this->workers_.at(worker_id).channel->Send(args); + } + TVMArgs RecvReplyPacked(int worker_id) final { return this->workers_.at(worker_id).channel->RecvReply(); } diff --git a/src/support/socket.h b/src/support/socket.h index ac13cd3f2d35..032cf257c045 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -370,7 +370,7 @@ class Socket { /*! * \brief a wrapper of TCP socket that hopefully be cross platform */ -class TCPSocket : public Socket { +class TCPSocket : public Socket, public dmlc::Stream { public: TCPSocket() : Socket(INVALID_SOCKET) {} /*! @@ -552,6 +552,10 @@ class TCPSocket : public Socket { ICHECK_EQ(RecvAll(&data[0], datalen), datalen); return data; } + + size_t Read(void* data, size_t size) final { return Recv(data, size); } + + size_t Write(const void* data, size_t size) final { return Send(data, size); } }; /*! \brief helper data structure to perform poll */ diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index 837b3a14f271..38aa757bf8f1 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -20,6 +20,9 @@ import numpy as np import pytest +import subprocess +import threading +import sys import tvm import tvm.testing @@ -29,7 +32,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T -from tvm.exec import disco_worker as _ +from tvm.exec import disco_worker as _ # pylint: disable=unused-import def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device): @@ -46,7 +49,75 @@ def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype): return host_array.numpy() -_all_session_kinds = [di.ThreadedSession, di.ProcessSession] +_SOCKET_SESSION_TESTER = None + + +def get_free_port(): + import socket + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + s.close() + return port + + +class SocketSessionTester: + def __init__(self, num_workers): + num_nodes = 2 + num_groups = 1 + assert num_workers % num_nodes == 0 + num_workers_per_node = num_workers // num_nodes + server_host = "localhost" + server_port = get_free_port() + self.sess = None + + def start_server(): + self.sess = di.SocketSession( + num_nodes, num_workers_per_node, num_groups, server_host, server_port + ) + + thread = threading.Thread(target=start_server) + thread.start() + + cmd = "tvm.exec.disco_remote_socket_session" + self.remote_nodes = [] + for _ in range(num_nodes - 1): + self.remote_nodes.append( + subprocess.Popen( + [ + "python3", + "-m", + cmd, + server_host, + str(server_port), + str(num_workers_per_node), + ], + stdout=sys.stdout, + stderr=sys.stderr, + ) + ) + + thread.join() + + def __del__(self): + for node in self.remote_nodes: + node.kill() + if self.sess is not None: + self.sess.shutdown() + del self.sess + + +def create_socket_session(num_workers): + global _SOCKET_SESSION_TESTER + if _SOCKET_SESSION_TESTER is not None: + del _SOCKET_SESSION_TESTER + _SOCKET_SESSION_TESTER = SocketSessionTester(num_workers) + assert _SOCKET_SESSION_TESTER.sess is not None + return _SOCKET_SESSION_TESTER.sess + + +_all_session_kinds = [di.ThreadedSession, di.ProcessSession, create_socket_session] @pytest.mark.parametrize("session_kind", _all_session_kinds) @@ -157,6 +228,11 @@ def main(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor((16, 8), dtype="floa y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype) np.testing.assert_equal(y_nd, y_np) + # sync all workers to make sure the temporary files are cleaned up after all workers + # finish the execution + for i in range(num_workers): + sess._sync_worker(i) + @pytest.mark.parametrize("session_kind", _all_session_kinds) def test_vm_multi_func(session_kind): @@ -220,10 +296,17 @@ def transpose_2( np.testing.assert_equal(y_nd, y_np) np.testing.assert_equal(z_nd, x_np) + # sync all workers to make sure the temporary files are cleaned up after all workers + # finish the execution + for i in range(num_workers): + sess._sync_worker(i) + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("num_workers", [1, 2, 4]) def test_num_workers(session_kind, num_workers): + if session_kind == create_socket_session and num_workers < 2: + return sess = session_kind(num_workers=num_workers) assert sess.num_workers == num_workers From df33d73ceca1d0c4ba280cfbcce504b232111d4c Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Subramaniam Date: Fri, 26 Jul 2024 19:08:27 +0530 Subject: [PATCH 031/202] [LLVM] Fix for getHostCPUFeatures API change (#17199) This patch fixes a minor API change in latest LLVM. --- src/target/llvm/codegen_llvm.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6098a3f32f0d..4c5bea8c9b4b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -2315,6 +2315,16 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUName").set_body_typed([]() -> st TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") .set_body_typed([]() -> Map { +#if TVM_LLVM_VERSION >= 200 + Map ret; + auto features = llvm::sys::getHostCPUFeatures(); + for (auto it = features.begin(); it != features.end(); ++it) { + std::string name = it->getKey().str(); + bool value = it->getValue(); + ret.Set(name, IntImm(DataType::Bool(), value)); + } + return ret; +#else llvm::StringMap features; if (llvm::sys::getHostCPUFeatures(features)) { Map ret; @@ -2325,6 +2335,7 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") } return ret; } +#endif LOG(WARNING) << "Current version of LLVM does not support feature detection on your CPU"; return {}; }); From 4330c110550242571da017a1b15ae0b765723ae8 Mon Sep 17 00:00:00 2001 From: FranckQC <89943638+FranckQC@users.noreply.github.com> Date: Sat, 27 Jul 2024 23:32:22 -0500 Subject: [PATCH 032/202] [Hexagon] Fix LWP assembly handler (predicate register) (#17204) * Fix LWP assembly handler (predicate register) (#2216) This solved the issue with LWP that appears with maxpool. The problem was that the LWP handler was forgetting to save p0 (used by the handler). This predicate register needs to be saved too, just like r0-r5, as it had been decided that it was the responsibility of the handler to save everything (even these theoretically caller-saved registers). Said differently, since it had been decided that calling the LWP handler would not follow the normal ABI, and that the LWP handler would save everything it touches (even normally caller-saved registers like r0-r15 and p0-3), then it absolutely needs to save the predicate registers too (in particular p0, which was causing the issue). The issue appeared only with maxpool because it's the only one that had a state saved in p0 before calling the LWP handler. And this call destroyed the content of what it had saved, making it subsequently branch to different portions of the code. Fix: Allocate 32 bytes (instead of 24 previously), in order to save p3:0, and I save those at the bottom of the stack. Restore it at the end of the LWP handler. * Remove training spaces --------- Co-authored-by: Slama, Franck --- src/runtime/hexagon/profiler/lwp_handler.S | 25 +++++++++++++++------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/runtime/hexagon/profiler/lwp_handler.S b/src/runtime/hexagon/profiler/lwp_handler.S index 611c0713111a..8cd02dd828f4 100644 --- a/src/runtime/hexagon/profiler/lwp_handler.S +++ b/src/runtime/hexagon/profiler/lwp_handler.S @@ -50,12 +50,17 @@ handler itself. .falign .type lwp_handler,@function lwp_handler: - { allocframe(#24) // Allocate 24 bytes on the stack to save R0-R5 registers + { + allocframe(#32) // Allocate 32 bytes on the stack to save R0-R5 registers (6*4bytes) and P0-P3 (4*1byte) + 4 unused bytes as the stack has to be 8-bytes aligned memd(r29+#-16) = r5:4 // Save R5,R4 + r5 = p3:0 // We will save P3:0 but we need an intermediate usual register (R5) that has already been saved + } + { + memd(r29+#16) = r3:2 // Save R3,R2 + memd(r29+#8) = r1:0 // Save R1, R0 } { - memd(r29+#8) = r3:2 // Save R3,R2 - memd(r29+#0) = r1:0 // Save R1, R0 + memw(r29+#0) = r5 // Save P3:0 (via R5) r2 = add(pc,##_GLOBAL_OFFSET_TABLE_@PCREL) // Get GOT address } { @@ -102,14 +107,18 @@ lwp_handler: memw(r5+#8) = r0 // Save lower 32 bits } .falign -.LBB0_3: +.LBB0_3: // Restore the registers from the stack + { + r1 = memw(r29+#0) // We will restore P3:0 but need an intermediate usual register (R1) that hasn't already been restored + r5:4 = memd(r29+#24) // Restore R5:4 + } { - r5:4 = memd(r29+#16) // Restore the registers from the stack - r3:2 = memd(r29+#8) + r3:2 = memd(r29+#16) // Restore R3:2 + p3:0 = r1 // Restore P3:0 (via R1, not yet restored) } { - r1:0 = memd(r29+#0) - dealloc_return // Deallocate the stack and return + r1:0 = memd(r29+#8) // Restore R1:0 + dealloc_return // Deallocate the stack and return } .Lfunc_end0: .size lwp_handler, .Lfunc_end0-lwp_handler From f62445cdd96a415d332585aa9702eaf1df3cf972 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 28 Jul 2024 13:57:09 -0700 Subject: [PATCH 033/202] [Relax] Disable fusion for fetching from the packed params in FuseOps (#17198) * [Relax] Disable fusion for fetching from the packed params in FuseOps The order of bindings in the fusion result is determined by the first binding in each partition group. When the packed param tuple is used, the function usually begins with a numbers of `TupleGetItem` to unpack the param tuple. Previously `TupleGetItem` is treated as `kInjective`, this causes any operation that relies purely on these params to be moved to the beginning of the function and increases the memory usage of the intermediate results. * lint --- src/relax/transform/fuse_ops.cc | 19 +++++++- tests/python/relax/test_transform_fuse_ops.py | 48 +++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 6030a28d93b6..e791aeab061d 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -147,6 +147,12 @@ class GraphCreator : public ExprVisitor { SetNodePattern(param_node, OpPatternKind::kOpaque); AddToPostDFSOrder(param_node, param.get()); } + if (auto opt_num_input = func->GetAttr(attr::kNumInput)) { + for (int i = static_cast(opt_num_input.value()->value); + i < static_cast(func->params.size()); ++i) { + input_params_.insert(func->params[i].get()); + } + } ExprVisitor::VisitExpr_(func); } @@ -224,8 +230,15 @@ class GraphCreator : public ExprVisitor { IndexedForwardGraph::Node* binding_var_node) { ICHECK_NOTNULL(binding_var_node); - SetNodePattern(binding_var_node, OpPatternKind::kInjective); - VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective); + auto pattern = OpPatternKind::kInjective; + if (input_params_.count(tuple_item->tuple.as())) { + // TupleGetItem for fetching the parameter from the packed param tuple is treated as opaque + // and won't be fused. This prevents the usage of packed param tuple changes the order of the + // fusion result as the function usually begins with fetching the parameters. + pattern = OpPatternKind::kOpaque; + } + SetNodePattern(binding_var_node, pattern); + VisitLeaf(tuple_item->tuple, binding_var_node, pattern); } void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) { @@ -354,6 +367,8 @@ class GraphCreator : public ExprVisitor { IndexedForwardGraph graph_; /*! \brief The graph nodes whose patterns are set */ std::unordered_set initialized_nodes_; + /*! \brief The model params in the function input */ + std::unordered_set input_params_; }; /*! diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 3cd608d8ee8f..17bf58613294 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -1642,5 +1642,53 @@ def main( _check(Module, Expected) +def test_packed_params(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func(private=True) + def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(16)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i0, i1 in T.grid(T.int64(16), T.int64(16)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(lv[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.Cast("float32", lv[v_i0, v_i1]) + + @T.prim_func(private=True) + def matmul(x: T.Buffer((T.int64(16), T.int64(16)), "float32"), lv2: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_matmul: T.Buffer((T.int64(16), T.int64(16)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0, ax1, k in T.grid(T.int64(16), T.int64(16), T.int64(16)): + with T.block("T_matmul"): + v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k]) + T.reads(x[v_ax0, v_k], lv2[v_k, v_ax1]) + T.writes(T_matmul[v_ax0, v_ax1]) + with T.init(): + T_matmul[v_ax0, v_ax1] = T.float32(0) + T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + x[v_ax0, v_k] * lv2[v_k, v_ax1] + + @R.function + def main(x: R.Tensor((16, 16), dtype="float32"), packed_params: R.Tuple(R.Tensor((16, 16), dtype="float16"), R.Tensor((16, 16), dtype="float16"))) -> R.Tensor((16, 16), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + lv: R.Tensor((16, 16), dtype="float16") = packed_params[0] + lv1: R.Tensor((16, 16), dtype="float16") = packed_params[1] + lv2 = R.call_tir(cls.cast, (lv,), out_sinfo=R.Tensor((16, 16), dtype="float32")) + lv3 = R.call_tir(cls.matmul, (x, lv2), out_sinfo=R.Tensor((16, 16), dtype="float32")) + lv4 = R.call_tir(cls.cast, (lv1,), out_sinfo=R.Tensor((16, 16), dtype="float32")) + lv5 = R.call_tir(cls.matmul, (lv3, lv4), out_sinfo=R.Tensor((16, 16), dtype="float32")) + gv: R.Tensor((16, 16), dtype="float32") = lv5 + R.output(gv) + return gv + # fmt: on + + Expected = Before + _check(Before, Expected) + + if __name__ == "__main__": tvm.testing.main() From 2c9af0f500c04383aa7220ab2c9220a608f75cbf Mon Sep 17 00:00:00 2001 From: Nestor Qin Date: Mon, 29 Jul 2024 08:17:55 -0400 Subject: [PATCH 034/202] [Runtime] Allow aborting fetchNDArray through AbortSignal (#17208) [Runtime] Allow aborting fetchNDArray --- web/src/artifact_cache.ts | 11 ++++++----- web/src/runtime.ts | 13 +++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index f833df1be523..9690ed3320b9 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -58,13 +58,14 @@ export interface ArtifactCacheTemplate { * * @param url: The url to the data to be cached. * @param storetype: Only applies to `ArtifactIndexedDBCache`. Since `indexedDB` stores the actual + * @param signal: An optional AbortSignal to abort data retrival * data rather than a request, we specify `storagetype`. There are two options: * 1. "json": IndexedDB stores `fetch(url).json()` * 2. "arraybuffer": IndexedDB stores `fetch(url).arrayBuffer()` * * @note This is an async function. */ - addToCache(url: string, storetype?: string): Promise; + addToCache(url: string, storetype?: string, signal?: AbortSignal): Promise; /** * check if cache has all keys in Cache @@ -126,8 +127,8 @@ export class ArtifactCache implements ArtifactCacheTemplate { } // eslint-disable-next-line @typescript-eslint/no-unused-vars - async addToCache(url: string, storetype?: string) { - const request = new Request(url); + async addToCache(url: string, storetype?: string, signal?: AbortSignal) { + const request = new Request(url, signal ? { signal } : undefined); if (this.cache === undefined) { this.cache = await caches.open(this.scope); } @@ -282,7 +283,7 @@ export class ArtifactIndexedDBCache implements ArtifactCacheTemplate { }); } - async addToCache(url: string, storetype?: string): Promise { + async addToCache(url: string, storetype?: string, signal?: AbortSignal): Promise { await this.initDB(); // await the initDB process // If already cached, nothing to do const isInDB = await this.isUrlInDB(url); @@ -290,7 +291,7 @@ export class ArtifactIndexedDBCache implements ArtifactCacheTemplate { return; } try { - const response = await fetch(url); + const response = await fetch(url, signal ? { signal } : undefined); if (!response.ok) { throw new Error('Network response was not ok'); } diff --git a/web/src/runtime.ts b/web/src/runtime.ts index fd7bcc6ab23b..d71c98e7d1bc 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -1444,13 +1444,15 @@ export class Instance implements Disposable { * @param device The device to be fetched to. * @param cacheScope The scope identifier of the cache * @param cacheType The type of the cache: "cache" or "indexedDB" + * @param signal An optional AbortSignal to abort the fetch * @returns The meta data */ async fetchNDArrayCache( ndarrayCacheUrl: string, device: DLDevice, cacheScope = "tvmjs", - cacheType = "cache" + cacheType = "cache", + signal?: AbortSignal, ): Promise { let artifactCache: ArtifactCacheTemplate; if (cacheType === undefined || cacheType.toLowerCase() === "cache") { @@ -1465,7 +1467,8 @@ export class Instance implements Disposable { const list = await artifactCache.fetchWithCache(jsonUrl, "json"); await this.fetchNDArrayCacheInternal( ndarrayCacheUrl, - list["records"] as Array, device, artifactCache); + list["records"] as Array, device, artifactCache, + signal); this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record) }; } @@ -1477,12 +1480,14 @@ export class Instance implements Disposable { * @param list The list of array data. * @param device The device to store the data to. * @param artifactCache The artifact cache + * @param signal An optional AbortSignal to abort the fetch */ private async fetchNDArrayCacheInternal( ndarrayCacheUrl: string, list: Array, device: DLDevice, - artifactCache: ArtifactCacheTemplate + artifactCache: ArtifactCacheTemplate, + signal?: AbortSignal, ) { const perf = compact.getPerformance(); const tstart = perf.now(); @@ -1537,7 +1542,7 @@ export class Instance implements Disposable { const shard = list[i]; const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; try { - await artifactCache.addToCache(dataUrl, "arraybuffer"); + await artifactCache.addToCache(dataUrl, "arraybuffer", signal); } catch (err) { this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); throw err; From 9e88018c3a56ab378dd11410a662ed5c3da1f4df Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 30 Jul 2024 10:45:57 -0500 Subject: [PATCH 035/202] [CI] Update dummy-variable regex for pylint (#17206) Prior to this commit, the regex used for pylint to identify dummy variables would correctly identify variables that start with an underscore (e.g. `_scale`), unless they have an underscore elsewhere in the name (e.g. `_scale_factor`). This leads to false positives from pylint for unused variables, as prefixing a variable with an underscore should mark a variable as intentionally unused. This commit updates the regex in TVM's `pylintrc` to match the current default value for `dummy-variables-rgx`, to allow unused variables to be named with a leading underscore, even if they also contain another underscore. --- tests/lint/pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lint/pylintrc b/tests/lint/pylintrc index 3b5e14d15bb0..90900b9e005a 100644 --- a/tests/lint/pylintrc +++ b/tests/lint/pylintrc @@ -252,7 +252,7 @@ init-import=no # A regular expression matching the name of dummy variables (i.e. expectedly # not used). -dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # List of additional names supposed to be defined in builtins. Remember that # you should avoid to define new builtins when possible. From 16f88223c6782ead92928d64bb4a3567cdb71419 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Jul 2024 08:37:57 -0500 Subject: [PATCH 036/202] [Transform][Relax] Handle `is_group` argument in IPC AllReduce (#17201) * [Transform][Relax] Handle `is_group` argument in IPC AllReduce The `relax.transform.IPCAllReduceRewrite` pass rewrites calls to `"runtime.disco.allreduce"` to instead call an optimized `"runtime.disco.cuda_ipc.custom_allreduce"` version. When the legalization of `R.ccl.allreduce` was updated in https://github.com/apache/tvm/pull/17180 to provide an `in_group` argument, the `IPCAllReduceRewrite` pass was not updated. This commit updates the `IPCAllReduceRewrite` to be handle the additional `in_group` argument. * lint fix * lint fix --- .../tvm/relax/transform/ipc_allreduce_rewrite.py | 10 +++++++--- .../test_transform_ipc_allreduce_rewrite.py | 16 ++++++++++------ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/transform/ipc_allreduce_rewrite.py b/python/tvm/relax/transform/ipc_allreduce_rewrite.py index df40181cb981..de5c22863403 100644 --- a/python/tvm/relax/transform/ipc_allreduce_rewrite.py +++ b/python/tvm/relax/transform/ipc_allreduce_rewrite.py @@ -97,8 +97,8 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re # Return if the call is not a summation all-reduce. return - assert len(call.args) == 3 - allreduce_input = call.args[0] + assert len(call.args) == 4 + allreduce_input, _strategy, _ingroup, allreduce_output = call.args alloc_tensor = self.alloc_map.get(allreduce_input, None) if alloc_tensor is None or alloc_tensor.args[3].value != "global": # Return if the allocation of all-reduce input is not recorded, @@ -113,9 +113,13 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re alloc_tensor.args[2], relax.StringImm("ipc_memory"), ) + self.binding_replacement_map[call] = relax.Call( relax.ExternFunc("runtime.disco.cuda_ipc.custom_allreduce"), - args=[call.args[0], relax.PrimValue(self.allreduce_strategy), call.args[2]], + # The "cuda_ipc.custom_allreduce" implementation does not + # yet support num_groups>1, and therefore does not use the + # `in_group` argument. + [allreduce_input, relax.PrimValue(self.allreduce_strategy), allreduce_output], ) diff --git a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py index f14953122ee3..da85423aafd7 100644 --- a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py +++ b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py @@ -37,7 +37,9 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore alloc1: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global") ) - _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1) + _: R.Object = R.call_packed( + "runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(True), alloc1 + ) return alloc1 @I.ir_module @@ -85,7 +87,9 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m * n]), R.dtype("float16"), R.prim_value(0), R.str("global") ) - _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1) + _: R.Object = R.call_packed( + "runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(False), alloc1 + ) return alloc1 @I.ir_module @@ -137,7 +141,9 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore alloc1: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global") ) - _: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([1]), alloc1) + _: R.Object = R.call_packed( + "runtime.disco.allreduce", lv1, R.shape([1]), R.prim_value(True), alloc1 + ) return alloc1 allreduce_strategy = 1 @@ -146,6 +152,4 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore if __name__ == "__main__": - test_ipc_allreduce_rewrite() - test_ipc_allreduce_spread_along_reshape() - test_ipc_allreduce_skip_reducer_other_than_sum() + tvm.testing.main() From 538343f7f0989c039ff0ba0fedcd5cef6f151c8e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Jul 2024 10:35:13 -0500 Subject: [PATCH 037/202] [CI] Reduce logging level when checking if docker image exists (#17221) Prior to this commit, the `image_exists` utility in `determine_docker_images.py` logged the full response for success, and the full HTTP error if an exception is caught. However, this is the expected behavior when loading a docker image from `tlcpackstaging`, such as the current images tagged with `20240428-060115-0b09ed018`. Logging this fallback as an error makes it difficult to find the first actual error that occurred in CI. This commit updates these logging statments `logging.info` and `logging.exception` to instead use `logging.debug`. --- ci/scripts/jenkins/determine_docker_images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/scripts/jenkins/determine_docker_images.py b/ci/scripts/jenkins/determine_docker_images.py index 41003958dd61..7e20c4f1384a 100755 --- a/ci/scripts/jenkins/determine_docker_images.py +++ b/ci/scripts/jenkins/determine_docker_images.py @@ -62,11 +62,11 @@ def image_exists(spec: str) -> bool: name, tag = spec.split(":") try: r = docker_api(f"repositories/{name}/tags/{tag}") - logging.info(f"Image exists, got response: {json.dumps(r, indent=2)}") + logging.debug(f"Image exists, got response: {json.dumps(r, indent=2)}") return True except urllib.error.HTTPError as e: # Image was not found - logging.exception(e) + logging.debug(e) return False From 8680c39c33b41b3ce18d3c6562a89a9b8355bb50 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Jul 2024 14:16:14 -0500 Subject: [PATCH 038/202] [Relax] Handle presence of R.call_tir in MergeCompositeFunctions (#17220) Prior to this commit, use of `R.call_tir` in the input to `MergeCompositeFunctions` would result in a segfault, when attempting to determine the `Group*` that contains the `relax::GlobalVar` of the callee. This commit updates `MergeCompositeFunctions` to check for `relax::GlobalVar` and `relax::Tuple` instances. Closes https://github.com/apache/tvm/issues/17120 --- .../transform/merge_composite_functions.cc | 22 +++- ...est_transform_merge_composite_functions.py | 119 ++++++++++++++++++ 2 files changed, 138 insertions(+), 3 deletions(-) diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 0dd14f5bb1af..0a3c4ff0a193 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -234,19 +234,35 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { void UpdateGroupDependencies(Group* group, const Array& args) { Group* group_root = group->FindRoot(); - for (const auto& arg : args) { - auto arg_group_root = memo_[arg]->FindRoot(); + std::function visit_expr = [&](Expr expr) { + if (expr.as()) return; + if (auto tuple = expr.as()) { + for (const auto& field : tuple->fields) { + visit_expr(field); + } + return; + } + + ICHECK(memo_.count(expr)) << "Could not find memo-ized group for expression of type " + << expr->GetTypeKey(); + auto arg_group_root = memo_[expr]->FindRoot(); + if (arg_group_root == group_root) { // If arg and the current node are in the same group, // there is nothing to update. - continue; + return; } + // Add the group of arg as dependency group_deps_[group_root].insert(arg_group_root); // Propagate dependencies of arg for (auto dep : group_deps_[arg_group_root]) { group_deps_[group_root].insert(dep); } + }; + + for (const auto& arg : args) { + visit_expr(arg); } } diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index cff832a21ff9..27537edd9e5f 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -20,6 +20,7 @@ from tvm import relax from tvm.script import relax as R from tvm.script import ir as I +from tvm.script import tir as T @tvm.script.ir_module @@ -1106,5 +1107,123 @@ def main( check(Module, Expected) +def test_handle_existence_of_call_tir(): + """MergeCompositeFunctions should accept R.call_tir as input + + No merging is required in this case, since the two composite + functions have `R.call_tir` between them. This is a regression + test, as previously the `Tuple` used to express of `R.call_tir` + caused a segfault. + + """ + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"): + cls = Before + with R.dataflow(): + B = cls.fused_relax_nn_relu(A) + C = R.call_tir(cls.relu, (B,), out_sinfo=R.Tensor([10], dtype="float32")) + D = cls.fused_relax_nn_gelu(C) + R.output(D) + return D + + @R.function(private=True) + def fused_relax_nn_relu( + Input: R.Tensor([10], dtype="float32") + ) -> R.Tensor([10], dtype="float32"): + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + with R.dataflow(): + Output = R.nn.relu(Input) + R.output(Output) + return Output + + @T.prim_func(private=True) + def relu( + Input: T.Buffer(T.int64(10), "float32"), + Output: T.Buffer(T.int64(10), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(10)): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + Output[vi] = T.max(Input[vi], T.float32(0)) + + @R.function(private=True) + def fused_relax_nn_gelu( + Input: R.Tensor([10], dtype="float32") + ) -> R.Tensor([10], dtype="float32"): + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + with R.dataflow(): + Output = R.nn.gelu(Input) + R.output(Output) + return Output + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"): + cls = Expected + with R.dataflow(): + B = cls.fused_relax_nn_relu1_compiler_A(A) + C = R.call_tir(cls.relu, (B,), out_sinfo=R.Tensor([10], dtype="float32")) + D = cls.fused_relax_nn_gelu1_compiler_A(C) + R.output(D) + return D + + @R.function + def fused_relax_nn_relu1_compiler_A( + Input: R.Tensor([10], dtype="float32") + ) -> R.Tensor([10], dtype="float32"): + R.func_attr({"Codegen": "compiler_A"}) + + @R.function + def composite_lambda( + Input: R.Tensor([10], dtype="float32") + ) -> R.Tensor([10], dtype="float32"): + R.func_attr({"Composite": "compiler_A.relu"}) + with R.dataflow(): + Output = R.nn.relu(Input) + R.output(Output) + return Output + + Output = composite_lambda(Input) + return Output + + @T.prim_func(private=True) + def relu( + Input: T.Buffer(T.int64(10), "float32"), + Output: T.Buffer(T.int64(10), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(10)): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + Output[vi] = T.max(Input[vi], T.float32(0)) + + @R.function + def fused_relax_nn_gelu1_compiler_A( + Input: R.Tensor([10], dtype="float32") + ) -> R.Tensor([10], dtype="float32"): + R.func_attr({"Codegen": "compiler_A"}) + + @R.function + def composite_lambda( + Input: R.Tensor([10], dtype="float32") + ) -> R.Tensor([10], dtype="float32"): + R.func_attr({"Composite": "compiler_A.gelu"}) + with R.dataflow(): + Output = R.nn.gelu(Input) + R.output(Output) + return Output + + Output = composite_lambda(Input) + return Output + + After = relax.transform.MergeCompositeFunctions()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": pytest.main([__file__]) From 24cd93df8b70dab4791cd383e542e9f697a3af0b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 1 Aug 2024 08:20:35 -0500 Subject: [PATCH 039/202] [Relax] Fix segfault in rewrite_bindings for MatchCast node (#17226) Prior to this commit, the `tvm.relax.dpl.rewrite_bindings` utility would segfault if its input contained a `DataflowBlock` whose first binding was a `MatchCast`. The root cause is use of an unintialized `const VarNode* cur_user_;` when collecting the variable usage. This variable is only initialized for `VarBinding` nodes, and may be used uninitialized if a `MatchCast` node is encountered before the first `VarBinding`. This uninitialized value is later dereferenced during while pattern-matching, causing a segfault. This commit provides a default value of `nullptr` for `MatcherUseDefAnalysis::cur_user_`, preventing the segfault. --- src/relax/ir/dataflow_block_rewriter.cc | 2 +- tests/python/relax/test_dataflow_pattern.py | 109 +++++++++++++------- 2 files changed, 75 insertions(+), 36 deletions(-) diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index fb08dfe96a17..88efad86cfdc 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -49,7 +49,7 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor { // caller -> callee table. std::map> caller2callees; - const VarNode* cur_user_; + const VarNode* cur_user_ = nullptr; void VisitBinding_(const VarBindingNode* binding) override { // init diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index f67b0530ca87..03a3beb2f27e 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1053,9 +1053,17 @@ def main( assert ctx.match_dfb(dfb) is None -def get_qkv_proj_rewriter( - inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 -): +def get_qkv_proj_rewriter(): + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + def qkv_proj_rewriter(matchings, _): inp = matchings[inp_pat] Q_weight = matchings[Q_weight_pat] @@ -1071,7 +1079,7 @@ def qkv_proj_rewriter(matchings, _): return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V} - return qkv_proj_rewriter + return ctx, qkv_proj_rewriter def test_combine_matmul_twice(): @@ -1123,21 +1131,63 @@ def expected( R.output(out) return out - with PatternContext() as ctx: - inp_pat = wildcard() - Q_weight_pat = wildcard() - K_weight_pat = wildcard() - V_weight_pat = wildcard() + ctx, rewriter = get_qkv_proj_rewriter() + rewritten = rewrite_bindings(ctx, rewriter, qkv_x2) + tvm.ir.assert_structural_equal(rewritten, expected) - matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) - matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) - matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) - rewriter = get_qkv_proj_rewriter( - inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 - ) - rewritten = rewrite_bindings(ctx, rewriter, qkv_x2) - tvm.ir.assert_structural_equal(rewritten, expected) +def test_dataflow_may_start_with_match_cast(): + """Inputs to rewrite_bindings may contain R.match_cast + + This is a regression test. In previous implementations, applying + `rewrite_bindings` when `R.match_cast` is the first binding of a + `R.dataflow` block would cause a segfault. + + """ + + @R.function(private=True) + def before( + x_untyped: R.Tensor, + w0_untyped: R.Tensor, + w1_untyped: R.Tensor, + w2_untyped: R.Tensor, + ): + with R.dataflow(): + x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32")) + w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32")) + w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32")) + w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32")) + out_0 = R.matmul(x, w0) + out_1 = R.matmul(x, w1) + out_2 = R.matmul(x, w2) + out = (out_0, out_1, out_2) + R.output(out) + return out + + @R.function(private=True) + def expected( + x_untyped: R.Tensor, + w0_untyped: R.Tensor, + w1_untyped: R.Tensor, + w2_untyped: R.Tensor, + ): + with R.dataflow(): + x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32")) + w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32")) + w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32")) + w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32")) + w_concat = R.concat((w0, w1, w2), axis=1) + out_concat = R.matmul(x, w_concat) + out_0 = R.strided_slice(out_concat, axes=[2], begin=[0], end=[640]) + out_1 = R.strided_slice(out_concat, axes=[2], begin=[640], end=[1280]) + out_2 = R.strided_slice(out_concat, axes=[2], begin=[1280], end=[1920]) + out = (out_0, out_1, out_2) + R.output(out) + return out + + ctx, rewriter = get_qkv_proj_rewriter() + rewritten = rewrite_bindings(ctx, rewriter, before) + tvm.ir.assert_structural_equal(rewritten, expected) def test_combine_matmul_emit_order(): @@ -1181,27 +1231,16 @@ def expected( R.output(out) return out - with PatternContext() as ctx: - inp_pat = wildcard() - Q_weight_pat = wildcard() - K_weight_pat = wildcard() - V_weight_pat = wildcard() + ctx, rewriter = get_qkv_proj_rewriter() - matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) - matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) - matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + rewritten = rewrite_bindings(ctx, rewriter, main) + tvm.ir.assert_structural_equal(rewritten, expected) - rewriter = get_qkv_proj_rewriter( - inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 - ) - rewritten = rewrite_bindings(ctx, rewriter, main) - tvm.ir.assert_structural_equal(rewritten, expected) - - # make sure it builds - mod = tvm.IRModule() - mod["main"] = rewritten + # make sure it builds + mod = tvm.IRModule() + mod["main"] = rewritten - rx.build(mod, target="llvm") + rx.build(mod, target="llvm") def test_combine_transposed_matmul_twice(): From 031f0475bea40f6dfb07c7d53e7078edfcbd300d Mon Sep 17 00:00:00 2001 From: Nestor Qin Date: Thu, 1 Aug 2024 11:42:49 -0400 Subject: [PATCH 040/202] [Runtime] Allow aborting fetchWithCache through AbortSignal (#17227) [Runtime] Add AbortSignal to fetchWithCache() --- web/src/artifact_cache.ts | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index 9690ed3320b9..794efdcedbc6 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -114,10 +114,11 @@ export class ArtifactCache implements ArtifactCacheTemplate { * fetch the corresponding url object in response or stored object format * @param url url * @param storetype the storage type for indexedDB + * @param signal an optional abort signal to abort fetching * @returns response in json, arraybuffer or pure response format */ - async fetchWithCache(url: string, storetype?: string): Promise { - await this.addToCache(url, storetype); + async fetchWithCache(url: string, storetype?: string, signal?: AbortSignal): Promise { + await this.addToCache(url, storetype, signal); const result = await this.cache.match(new Request(url)); if (result === undefined) { // Already called `addToCache()`, should expect the request in cache. @@ -242,8 +243,8 @@ export class ArtifactIndexedDBCache implements ArtifactCacheTemplate { }) } - async fetchWithCache(url: string, storetype?: string): Promise { - await this.addToCache(url, storetype); + async fetchWithCache(url: string, storetype?: string, signal?: AbortSignal): Promise { + await this.addToCache(url, storetype, signal); let result = await this.asyncGetHelper(url); if (result === null) { // previously null data in cache or somehow failed to add to cache, delete and retry From 3a02309ed85d308da1b1af127bc97b5b22589a43 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 2 Aug 2024 22:14:32 +0800 Subject: [PATCH 041/202] [Relax] FuseTransposeMatmul Pass (#17234) Introduce a new pass to fuse transpose and matmul, which specially for `Linear` ops in PyTorch and NNModule. Note that this pass is migrated from MLC-LLM. Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao --- python/tvm/relax/transform/__init__.py | 1 + .../relax/transform/fuse_transpose_matmul.py | 175 ++++++++++++++++++ .../test_transform_fuse_transpose_matmul.py | 82 ++++++++ 3 files changed, 258 insertions(+) create mode 100644 python/tvm/relax/transform/fuse_transpose_matmul.py create mode 100644 tests/python/relax/test_transform_fuse_transpose_matmul.py diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 5e76fff6bd1e..5789e2fcf235 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -90,6 +90,7 @@ from .optimize_layout_transform import OptimizeLayoutTransform from .remove_redundant_reshape import RemoveRedundantReshape from .fast_math import FastMathTransform +from .fuse_transpose_matmul import FuseTransposeMatmul from .attach_external_modules import AttachExternModules # Import to register the legalization functions. diff --git a/python/tvm/relax/transform/fuse_transpose_matmul.py b/python/tvm/relax/transform/fuse_transpose_matmul.py new file mode 100644 index 000000000000..1d2324a28b3e --- /dev/null +++ b/python/tvm/relax/transform/fuse_transpose_matmul.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""A compiler pass that fuses transpose + matmul and generate TIR function. +Note that +1. Please put the pass before LegalizeOps pass. +2. The pass only works for XW^T but not X^TW +3. The pass would rewrite the relax ops into TIR functions. If you'd like to dispatch the + ops into library (e.g. cuBLAS) calls, please run dispatch pass before this pass. +""" + +import tvm +from tvm import IRModule, relax, te, tir +from tvm.relax.dpl.pattern import is_op, wildcard +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul") +class FuseTransposeMatmul: # pylint: disable=too-few-public-methods + """A compiler pass that fuses transpose + matmul.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + mod = relax.transform.FuseOpsByPattern( + [ + ( + "transpose_matmul_fuse", + *_pattern(), + ), + ] + )(mod) + transpose_matmul_codegen = _TransposeMatmulFuser(mod) + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + func = transpose_matmul_codegen.visit_expr(func) + transpose_matmul_codegen.builder_.update_func(g_var, func) + return transpose_matmul_codegen.builder_.get() + + +def _pattern(): + """Pattern for transpose + matmul.""" + # pylint: disable=invalid-name + w = wildcard() + x = wildcard() + wT = is_op("relax.permute_dims")(w) + o = is_op("relax.matmul")(x, wT) + # pylint: enable=invalid-name + annotations = {"o": o, "w": w, "x": x, "wT": wT} + + def _check(context: relax.transform.PatternCheckContext) -> bool: + transpose_call = context.annotated_expr["wT"] + ndim = transpose_call.args[0].struct_info.ndim + if ndim == -1: + return False + if ndim == 2 and transpose_call.attrs.axes is None: + return True + axes = list(range(ndim)) + axes[-1], axes[-2] = axes[-2], axes[-1] + return list(transpose_call.attrs.axes) == axes + + return o, annotations, _check + + +# pylint: disable=missing-docstring,invalid-name + + +@mutator +class _TransposeMatmulFuser(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod): + super().__init__(mod) + + def visit_call_( # pylint: disable=arguments-renamed + self, + call: relax.Call, + ) -> relax.Expr: + out_dtype = None + + def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + nonlocal out_dtype + a_shape = list(a.shape) + b_shape = list(b.shape) + a_prepended = False + b_appended = False + if len(a_shape) == 1: + a_prepended = True + a_shape.insert(0, 1) + if len(b_shape) == 1: + b_appended = True + b_shape.append(1) + + is_a_larger = len(a_shape) > len(b_shape) + offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) + + a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + bT_shape = list(b.shape) + bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1] + bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape)) + output_shape = self.builder_.normalize( + relax.op.matmul(a_relax, bT_relax) + ).struct_info.shape + + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + + def multiply_compute(idx_reduce): + a_indices = [] + b_indices = [] + + for i in range(offset): + if is_a_larger: + a_indices.append(idx_spatial[i]) + else: + b_indices.append(idx_spatial[i]) + for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): + a_dim = a_shape[i if is_a_larger else i - offset] + b_dim = b_shape[i if not is_a_larger else i - offset] + dim_equal = a_dim == b_dim + if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0: + a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + a_indices.append(0 if a_dim_is_one else idx_spatial[i]) + b_indices.append(0 if b_dim_is_one else idx_spatial[i]) + else: + a_indices.append(idx_spatial[i]) + b_indices.append(idx_spatial[i]) + + if not a_prepended: + a_indices.append(idx_spatial[-2 + b_appended]) + a_indices.append(idx_reduce) + if not b_appended: + b_indices.append(idx_spatial[-1]) + b_indices.append(idx_reduce) + + dtype = out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="NT_matmul", + ) + + if isinstance(call.op, relax.GlobalVar): + function = self.builder_.get()[call.op] + if ( + "Composite" in function.attrs + and function.attrs["Composite"] == "transpose_matmul_fuse" + ): + out_dtype = function.ret_struct_info.dtype + return self.builder_.call_te( + te_transposed_matmul, + call.args[1], + call.args[0], + primfunc_name_hint="NT_matmul", + ) + + return super().visit_call_(call) diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py b/tests/python/relax/test_transform_fuse_transpose_matmul.py new file mode 100644 index 000000000000..4b2b1fff8aba --- /dev/null +++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, missing-docstring + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_transform_fuse_transpose_matmul(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor((128, 256), "float32"), + w: R.Tensor((128, 256), "float32"), + ) -> R.Tensor((128, 128), "float32"): + with R.dataflow(): + wT = R.permute_dims(w, [1, 0]) + o = R.matmul(x, wT) + R.output(o) + return o + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def NT_matmul( + x: T.Buffer((T.int64(128), T.int64(256)), "float32"), + w: T.Buffer((T.int64(128), T.int64(256)), "float32"), + NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)): + with T.block("NT_matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(x[v_i0, v_k], w[v_i1, v_k]) + T.writes(NT_matmul[v_i0, v_i1]) + with T.init(): + NT_matmul[v_i0, v_i1] = T.float32(0) + NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, v_k] * w[v_i1, v_k] + + @R.function + def main( + x: R.Tensor((128, 256), dtype="float32"), w: R.Tensor((128, 256), dtype="float32") + ) -> R.Tensor((128, 128), dtype="float32"): + cls = Expected + with R.dataflow(): + gv = R.call_tir( + cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), dtype="float32") + ) + R.output(gv) + return gv + + after = tvm.ir.transform.Sequential( + [ + relax.transform.FuseTransposeMatmul(), + relax.transform.FuseTIR(), # Only used for remove unused primitive function + ] + )(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +if __name__ == "__main__": + tvm.testing.main() From 219ae85d4b58c97b3438fc9c031728c78002d9ad Mon Sep 17 00:00:00 2001 From: Nestor Qin Date: Fri, 2 Aug 2024 17:49:01 -0400 Subject: [PATCH 042/202] [Runtime Patch] Add AbortSignal to fetchWithCache in ArtifactCacheTemplate interface (#17233) [Runtime] Add AbortSignal to fetchWithCache in ArtifactCacheTemplate interface --- web/src/artifact_cache.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index 794efdcedbc6..61ad021c7fef 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -47,11 +47,12 @@ export interface ArtifactCacheTemplate { * return the actual data object rather than the request. There are two options: * 1. "json": returns equivalent to `fetch(url).json()` * 2. "arraybuffer": returns equivalent to `fetch(url).arraybuffer()` + * @param signal: An optional AbortSignal allowing user to abort the fetching before its completion. * @return The data object (i.e. users do not need to call `.json()` or `.arraybuffer()`). * * @note This is an async function. */ - fetchWithCache(url: string, storetype?: string): Promise; + fetchWithCache(url: string, storetype?: string, signal?: AbortSignal): Promise; /** * Fetch data from url and add into cache. If already exists in cache, should return instantly. From 76b954a09e781b7f664b1d345e1494123c19484c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 3 Aug 2024 04:28:02 -0400 Subject: [PATCH 043/202] [3rdparty] Bump FlashInfer (#17236) This PR bumps FlashInfer and updates PagedKVCache accordingly for performance improvement. Some notes on this bump: * When the Grouped-Query Attention group size is at least 4 and FlashInfer is enabled, we use the prefill attn kernel for better performance. * We enlarge the temporary workspace for FlashInfer use accordingly, as FlashInfer in the current version may consume much larger workspace. We turn off the workspace when FlashInfer is not enabled. * We reduce the max block depth to be 2, in observation of the limited help of cascade inference when batch size is not large and the prompt reuse is low. --- 3rdparty/flashinfer | 2 +- src/runtime/relax_vm/paged_kv_cache.cc | 48 +++++++++++++------ ...tin_paged_attention_kv_cache_flashinfer.py | 13 ++++- ...me_builtin_paged_attention_kv_cache_tir.py | 13 ++++- 4 files changed, 58 insertions(+), 18 deletions(-) diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer index 7e9cc7ff42ca..0dd801d2027a 160000 --- a/3rdparty/flashinfer +++ b/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 7e9cc7ff42ca283c317061a877305d09a395fad2 +Subproject commit 0dd801d2027af89f3603cbbf68a76e9503bb2f57 diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 2fb8a72f4279..5aa1411ec154 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -54,11 +54,11 @@ namespace relax_vm { * \brief The maximum allowed block depth (a.k.a. number of common * prefixes) in paged KV cache. */ -constexpr const int kPagedKVCacheMaxBlockDepth = 5; +constexpr const int kPagedKVCacheMaxBlockDepth = 2; /*! \brief The maximum tree size of a single sequence in tree attention. */ constexpr const int kTreeAttnMaxTreeSize = 256; /*! \brief The 8MB workspace size for attention auxiliary data. */ -constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024; +constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024; /*! \brief The id of the temporary logical page, which is useful for sliding window. */ constexpr const int kPagedKVCacheTempPageId = -1; @@ -119,6 +119,9 @@ struct Block { void Reset() { page_ids.clear(); seq_length = 0; + start_pos = 0; + sink_length = 0; + sliding_window_offset = 0; parent_idx = -1; external_ref_cnt = 0; } @@ -169,11 +172,9 @@ struct Sequence { this->last_block_idx = last_block_idx; int32_t block_ptr = last_block_idx; // Go through each block in the sequence, sum up the length. - int depth = 0; while (true) { const Block& block = global_block_pool->at(block_ptr); this->seq_length += block.seq_length; - ++depth; if (block.parent_idx == -1) { break; } @@ -1078,8 +1079,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { dtype_aux_, preferred_host_device); for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { - temp_attn_workspace_.push_back( - NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device)); + if (NeedKernelBeginForward()) { + temp_attn_workspace_.push_back( + NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device)); + } qo_indptr_on_depths_view_.push_back(NDArray()); page_indptr_on_depths_view_.push_back(NDArray()); page_indices_on_depths_view_.push_back(NDArray()); @@ -1087,8 +1090,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { k_rope_pos_offset_view_.push_back(NDArray()); } // Additional workspace for the "prefill with ragged kv" kernel. - temp_attn_workspace_.push_back( - NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device)); + if (NeedKernelBeginForward()) { + temp_attn_workspace_.push_back( + NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device)); + } temp_attn_q_device_ = NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, device); @@ -1531,6 +1536,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0]; + if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) { + // When GQA group size is at least 4 and FlashInfer is enabled, + // we always use prefill kernel for better performance. + std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false); + } + if (append_before_attn_) { // Right now we use different kernels when depth is 1 or not 1. // For the case where maximum depth is 1, we create the auxiliary @@ -2196,11 +2207,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { use_decode_kernel}; } + /*! \brief Check whether BeginForward for kernels is needed. */ + bool NeedKernelBeginForward() { + return f_attention_prefill_begin_forward_.defined() && + f_attention_decode_begin_forward_.defined() && + f_attention_prefill_ragged_begin_forward_.defined(); + } + /*! \brief Invoke the "begin forward" functions of underlying kernels. */ void KernelBeginForward() { - if (!f_attention_prefill_begin_forward_.defined() || - !f_attention_decode_begin_forward_.defined() || - !f_attention_prefill_ragged_begin_forward_.defined()) { + if (!NeedKernelBeginForward()) { return; } @@ -2214,8 +2230,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } else { f_attention_prefill_ragged_begin_forward_.value()( - temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, - num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_); + temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), + cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_, + num_kv_heads_, head_dim_, copy_stream_); if (support_sliding_window_) { return; } @@ -2232,8 +2249,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } else { f_attention_prefill_begin_forward_.value()( /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(), - length_info_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_, - copy_stream_); + page_indptr_on_depths_host_[d].as_ndarray(), + static_cast(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_, + num_kv_heads_, head_dim_, page_size_, copy_stream_); } } } diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index bade04a7d753..cab10f84cddf 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -29,7 +29,7 @@ from tvm.script import tir as T reserved_nseq = 32 -maximum_total_seq_length = 1024 +maximum_total_seq_length = 2048 prefill_chunk_size = 512 page_size = 16 num_layers = 4 @@ -249,6 +249,7 @@ def copy_single_page( ): for t in T.thread_binding(tx, thread="threadIdx.x"): with T.block("copy"): + T.where(b * tx + t < copy_length * num_heads * head_dim) vh = T.axis.spatial( num_heads, T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), @@ -662,6 +663,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode): cached_v.pop(i) verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v) + # Test fork after page recycle + apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v) + + apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v) + @pytest.mark.skip(reason="Require FlashInfer enabled") def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode): diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 9192bb901ff0..3c85a13e4cfc 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -33,7 +33,7 @@ from tvm.target import Target reserved_nseq = 32 -maximum_total_seq_length = 1024 +maximum_total_seq_length = 2048 prefill_chunk_size = 512 page_size = 16 num_layers = 4 @@ -615,6 +615,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + # Test fork after page recycle + apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v) + + apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v) + apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v) + @tvm.testing.requires_gpu @tvm.testing.requires_cuda @@ -2547,6 +2557,7 @@ def copy_single_page( ): for t in T.thread_binding(tx, thread="threadIdx.x"): with T.block("copy"): + T.where(b * tx + t < copy_length * num_heads * head_dim) vh = T.axis.spatial( num_heads, T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), From 21c12fb1243a79df2aea8b83956c6b0b914cf4a5 Mon Sep 17 00:00:00 2001 From: senlyu163 <70838408+senlyu163@users.noreply.github.com> Date: Sat, 3 Aug 2024 20:45:36 +0800 Subject: [PATCH 044/202] [Bugfix][Cutlass] fix cutlass instantiate attention template bugs (#17229) [Bugfix][Cutlass] fix cutlass attention template --- python/tvm/contrib/cutlass/attention_operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 518778ec52ed..69298453cb87 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -111,7 +111,7 @@ def instantiate_attention_template(attrs): if (accumulator_buf_size <= ${workspace}->shape[0]) { p.output_accum_ptr = static_cast(${workspace}->data); } else { - accumulator_buf_size = true; + accumulator_buf_allocated = true; cudaMalloc( &p.output_accum_ptr, accumulator_buf_size From cd09ab64b5ccf6ff0a96d887a968acd4602188a8 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 3 Aug 2024 20:01:01 -0400 Subject: [PATCH 045/202] [Runtime] Reorganize PagedKVCache attn kernel invocation (#17237) This PR reorganizes the attention kernel invocation logic in the PagedKVCache, so that in cases of sequence fork, we can effectively merge one ragged-prefill kernel and a decode kernel into a single decode kernel. --- src/relax/transform/fuse_ops.cc | 2 +- src/runtime/relax_vm/paged_kv_cache.cc | 127 +++++++++++++------------ 2 files changed, 65 insertions(+), 64 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index e791aeab061d..85c739e08353 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -646,7 +646,7 @@ class FunctionCreator : public ExprMutator { return tvm::tir::UndefinedVars(prim_value->value).empty(); } else if (const auto* shape_expr = expr.as()) { return std::all_of(shape_expr->values.begin(), shape_expr->values.end(), - [this](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); }); + [](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); }); } return false; } diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 5aa1411ec154..cf5de97202cc 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -1535,7 +1535,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_); } - append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0]; + append_before_attn_ = !support_sliding_window_ && use_decode_kernel_.back(); if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) { // When GQA group size is at least 4 and FlashInfer is enabled, // we always use prefill kernel for better performance. @@ -2220,39 +2220,33 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { return; } - if (append_before_attn_) { - if (!support_sliding_window_) { + if (!append_before_attn_) { + if (is_chain_) { + f_attention_prefill_ragged_begin_forward_.value()( + temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), + cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_, + num_kv_heads_, head_dim_, copy_stream_); + } else { + LOG(FATAL) << "Kernel BeginForward doesn't support tree attn."; + } + } + for (int d = 0; d < num_depths_; ++d) { + if (page_indices_on_depths_view_[d]->shape[0] == 0) { + continue; + } + CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; + if (use_decode_kernel_[d]) { f_attention_decode_begin_forward_.value()( - /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_host_[0].as_ndarray(), - last_page_len_on_depths_host_[0].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_, + d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(), + last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); - } - } else { - f_attention_prefill_ragged_begin_forward_.value()( - temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), - cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_, - num_kv_heads_, head_dim_, copy_stream_); - if (support_sliding_window_) { - return; - } - for (int d = 0; d < num_depths_; ++d) { - if (page_indices_on_depths_view_[d]->shape[0] == 0) { - continue; - } - if (use_decode_kernel_[d]) { - f_attention_decode_begin_forward_.value()( - d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(), - last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, - head_dim_, page_size_, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); - } else { - f_attention_prefill_begin_forward_.value()( - /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(), - page_indptr_on_depths_host_[d].as_ndarray(), - static_cast(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_, - num_kv_heads_, head_dim_, page_size_, copy_stream_); - } + } else { + f_attention_prefill_begin_forward_.value()( + /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(), + page_indptr_on_depths_host_[d].as_ndarray(), + static_cast(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_, + num_kv_heads_, head_dim_, page_size_, copy_stream_); } } } @@ -2271,15 +2265,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { PackedFunc f_decode = !support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_; CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; - if (append_before_attn_) { - f_decode( - /*depth=*/0, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[0], - page_indices_on_depths_view_[0], length_info_on_depths_view_[0], - k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, merged_attn_scores_view_, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, - attn_score_scaling_factor); - } else { - // Compute appended text self-attention + + bool is_first_kernel = true; + if (!append_before_attn_) { + // The first part of attention, which only involves the q and the newly appended k/v. + is_first_kernel = false; if (is_chain_) { // If the batch does not form a tree, use raggedness prefill kernel. f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data, @@ -2301,32 +2291,43 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, attn_score_scaling_factor, cur_batch_size_); } + } - for (int d = 0; d < num_depths_; ++d) { - if (page_indices_on_depths_view_[d]->shape[0] == 0) { - continue; - } - if (use_decode_kernel_[d]) { - // Use decode kernel for depth d - f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], - page_indices_on_depths_view_[d], length_info_on_depths_view_[d], - k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_, - temp_attn_scores_view_, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, - attn_score_scaling_factor); - } else { - // Use prefill kernel for depth d - f_prefill( - /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], - length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, - temp_attn_output_view_, temp_attn_scores_view_, - /*causal=*/0, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, - attn_score_scaling_factor); - } + for (int d = 0; d < num_depths_; ++d) { + if (page_indices_on_depths_view_[d]->shape[0] == 0) { + continue; + } + NDArray attn_output; + NDArray attn_scores; + if (is_first_kernel) { + attn_output = output; + attn_scores = merged_attn_scores_view_; + } else { + attn_output = temp_attn_output_view_; + attn_scores = temp_attn_scores_view_; + } + if (use_decode_kernel_[d]) { + // Use decode kernel for depth d + f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], + page_indices_on_depths_view_[d], length_info_on_depths_view_[d], + k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores, + /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, + attn_score_scaling_factor); + } else { + // Use prefill kernel for depth d + f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], + page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], + length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], + q_rope_position_map_view_, attn_output, attn_scores, /*causal=*/0, + /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, + attn_score_scaling_factor); + } + + if (!is_first_kernel) { f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_, temp_attn_scores_view_); + } else { + is_first_kernel = false; } } } From bd7f1f8de046d598bcf15ea6d7dffc596d5119a4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 5 Aug 2024 01:27:37 -0500 Subject: [PATCH 046/202] [TIR] Validate tir::Buffer axis_separators on construction (#17219) * [TIR] Validate tir::Buffer axis_separators on construction Prior to this commit, the `axis_separators` field of a TIR buffer wasn't validated until the `tir.FlattenBuffer` legalization pass. Delaying the error until this point makes it difficult to determine where it invalid `axis_separators` were initially defined. This commit updates the `tir::Buffer` constructor to validate the `axis_separators` field immediately, allowing these invalid values to be caught on construction. Closes https://github.com/apache/tvm/issues/17215 * Update metaschedule primitive to only set axis_separators of alloc * Allow axis separators to be increasing, rather than strictly increasing --- src/tir/ir/buffer.cc | 45 ++++++++++++------- .../primitive/layout_transformation.cc | 15 ++++--- tests/python/tir-base/test_tir_buffer.py | 12 +++-- .../test_tir_schedule_set_axis_separator.py | 4 +- 4 files changed, 51 insertions(+), 25 deletions(-) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 025605333138..b7c4eb1d42ec 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -334,24 +334,37 @@ inline Array BufferOffset(const BufferNode* n, Array index, return offsets; } -Buffer Buffer::GetFlattenedBuffer() const { - auto self = operator->(); - +static void ValidateAxisSeparators(const Array& axis_separators, size_t buffer_dim) { // These checks ensure that all output axes contain at least one // input axis. - for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) { - auto sep = self->axis_separators[i]->value; - auto next_sep = self->axis_separators[i + 1]->value; - ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order."; - } - if (self->axis_separators.size()) { - auto first_sep = self->axis_separators[0]->value; - ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, " - << "so that first output axis contains at least one input axis"; - auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value; - ICHECK_LT(last_sep, self->shape.size()) - << "Last output axis must contain at least one input axis."; + for (size_t i = 0; (i + 1) < axis_separators.size(); i++) { + auto sep = axis_separators[i]->value; + auto next_sep = axis_separators[i + 1]->value; + CHECK_LE(sep, next_sep) << "ValueError: " + << "Axis separators must be in increasing order, " + << "but axis_separators[" << i << "] = " << sep + << " is greater than or equal to axis_separators[" << (i + 1) + << "] = " << next_sep << "."; + } + if (axis_separators.size()) { + auto first_sep = axis_separators[0]->value; + CHECK_GE(first_sep, 0) << "ValueError: " + << "All axis separators must be non-negative. " + << "However, the axis_separators[0] = " << first_sep; + auto last_sep = axis_separators[axis_separators.size() - 1]->value; + CHECK_LE(last_sep, buffer_dim) + << "ValueError: " + << "All axis separators must be within the range " + << "0 <= sep <= buffer_dim. " + << "However, the last axis_separators[" << (axis_separators.size() - 1) + << "] = " << last_sep << " is greater than the buffer's dimensionality of " << buffer_dim; } +} + +Buffer Buffer::GetFlattenedBuffer() const { + auto self = operator->(); + + ValidateAxisSeparators(self->axis_separators, self->shape.size()); Array output_shape; if (self->strides.size()) { @@ -565,6 +578,8 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array ICHECK(data->type_annotation.as()->element_type.as()) << "Variable " << data->name_hint << " does not point to a primitive."; + ValidateAxisSeparators(axis_separators, shape.size()); + auto n = make_object(); n->data = std::move(data); n->dtype = dtype; diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index f1e9106a635b..8b95e0dc622f 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1485,11 +1485,16 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { if (it != buffer_var_map_.end()) { const Buffer& new_source_buffer = it->second; Buffer new_target_buffer = match_buffer->buffer; - new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; - if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) { - LOG(WARNING) - << "Target buffer in match_buffer doesn't have the same dimensionality as its source " - "buffer. `axis_separators` for the target buffer might be incorrect."; + + if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) { + new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; + } else { + new_target_buffer.CopyOnWrite()->axis_separators = + Array(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0)); + LOG(WARNING) << "Buffer view " << new_target_buffer + << " has different dimensionality than backing buffer " << new_source_buffer + << ". The `axis_separators` for " << new_target_buffer << "." + << "`axis_separators` for the view might be incorrect."; } buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer; return MatchBufferRegion(new_target_buffer, diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index 1ab7662b0b6b..b4b773197b14 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -109,9 +109,10 @@ def test_buffer_index_merge_mult_mod(): A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1)) def assert_simplified_equal(index_simplified, index_direct): - tvm.ir.assert_structural_equal( - index_simplified, index_direct - ), "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct) + ( + tvm.ir.assert_structural_equal(index_simplified, index_direct), + "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct), + ) idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod @@ -276,5 +277,10 @@ def test_buffer_flatten_uses_axis_separators(): tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32]) +def test_invalid_axis_separators_raises_exception(): + with pytest.raises(ValueError): + tvm.tir.decl_buffer([1], axis_separators=[1, 2]) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py index 76a6ade42f50..788e17e77146 100644 --- a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py +++ b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py @@ -94,12 +94,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "flo for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1]) + B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0]) B_subregion0[()] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1]) + B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0]) C[vi, vj] = B_subregion1[()] + T.float32(1) From 5a67a00bcbb53731bbf53db7801fa16c8c9eb9f2 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 5 Aug 2024 21:17:48 +0800 Subject: [PATCH 047/202] [Unity][Frontend] Add Sqrt Op (#17228) * Update op.py * Update test_frontend_nn_op.py * Update op.py with annotation * Update core.py(typo in annotation) --- python/tvm/relax/frontend/nn/core.py | 2 +- python/tvm/relax/frontend/nn/op.py | 22 ++++++++++++++++++++++ tests/python/relax/test_frontend_nn_op.py | 6 ++++-- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 3511c38a2b7c..21118b1cb8af 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -17,7 +17,7 @@ """The core infra for nn.Module, which includes the following pieces: - Tensor, a wrapper on top of relax.Expr whose struct_info is a TensorStructInfo, providing more convenient access shape and dtype information. - Tensor is always symbolc and not bound to any concrete values. + Tensor is always symbolic and not bound to any concrete values. - Parameter, a special tensor which could be bound or not bound to concrete values. - Module, a container of nn.Parameters and sub nn.Modules. - Effect, a non-user-facing class that encloses potential side effects, for example, IO, diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index e1ba4483c741..17a40a8cce57 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1486,6 +1486,28 @@ def square(x: Tensor, name: str = "square") -> Tensor: return wrap_nested(_op.square(x._expr), name) +def sqrt(x: Tensor, name: str = "sqrt") -> Tensor: + """Computes the element-wise sqrt of the input tensor. + + Parameters + ---------- + x : Tensor + The input tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + Note + ---- + The input tensor is required to have float dtype + """ + return wrap_nested(_op.sqrt(x._expr), name) + + def get_timestep_embedding( x: Tensor, embedding_dim: int, diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index a632a867432b..6c3269195498 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -31,7 +31,8 @@ def test_unary(): class Model(Module): def test(self, x: Tensor): z0 = op.square(x) - return (x,) + z1 = op.sqrt(x) + return (z0, z1) # fmt: off @R.function @@ -39,7 +40,8 @@ def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object): R.func_attr({"num_input": 2}) with R.dataflow(): square: R.Tensor((1, 10), dtype="float32") = R.square(x) - gv1 = (x,), (_io,) + sqrt: R.Tensor((1, 10), dtype="float32") = R.sqrt(x) + gv1 = (square, sqrt), (_io,) R.output(gv1) return gv1 # fmt: on From 5f22be4d83ca698e316ac342f32f5b4d38155ca8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 5 Aug 2024 08:19:20 -0500 Subject: [PATCH 048/202] [FFI][RUNTIME] Introduce runtime boxed types for int/float/bool (#16183) * [Container] Support non-nullable types in Array::Map Prior to this commit, the `Array::Map` member function could only be applied to nullable object types. This was due to the internal use of `U()` as the default value for initializing the output `ArrayNode`, where `U` is the return type of the mapping function. This default constructor is only available for nullable types, and would result in a compile-time failure for non-nullable types. This commit replaces `U()` with `ObjectRef()` in `Array::Map`, removing this limitation. Since all items in the output array are overwritten before returning to the calling scope, initializing the output array with `ObjectRef()` does not violate type safety. * [FFI] Separate runtime types from IR types for int/float/bool Prior to this commit, `int`, `float`, and `bool` arguments from Python were converted to `IntImm`, `FloatImm`, and `Bool`. These are subtypes of `PrimExpr`, and should only be used at compile-time. By automatically applying this conversion as part of the FFI, these types are required to be present whenever a primitive is converted to a `tvm::ObjectRef`. This can become especially fragile for an end-user when storing objects into a TVM container. Because TVM containers require all contents to be `ObjectRef` subclasses, an automatic conversion may be applied on storing into a container, resulting in an unexpected type being retrieved from the container. For example, this currently occurs in Relax when extracting a `R.Prim` from a `R.Tuple`. This commit introduces a `Box` type for storage of boxed primitives at runtime, distinct from the IR types. * Primitive arguments provided to a PackedFunc that requires an `ObjectRef` will be converted to the corresponding boxed type. (e.g. Passing a Python `int` to a C++ function accepting `ObjectRef` produces a `Box`. * Boxed primitives provided to a PackedFunc that requires an unboxed primitive will be converted to the corresponding primitive. * PackedFunc return values of `ObjectRef` are converted to the corresponding primitive, if present. (e.g. If a `tuple_getitem` with static return type `ObjectRef` returns a `Box`, it will be unwrapped to a python `int`.) Together, these three rules provide backwards compatibility for existing PackedFunc definitions, while avoiding exposing the user to any container-induced type conversions betweeen primitive types and `ObjectRef`. * Fix unit test failure after merge * Fix breakage in new unit test --- include/tvm/ir/attrs.h | 76 +- include/tvm/ir/expr.h | 130 +++- include/tvm/ir/transform.h | 34 +- include/tvm/meta_schedule/schedule_rule.h | 8 +- include/tvm/relay/attrs/transform.h | 2 +- include/tvm/runtime/c_runtime_api.h | 5 +- .../tvm/runtime/container/boxed_primitive.h | 143 ++++ include/tvm/runtime/container/variant.h | 2 +- include/tvm/runtime/ndarray.h | 2 + include/tvm/runtime/packed_func.h | 689 ++++++++++++++---- include/tvm/target/target.h | 10 +- include/tvm/target/target_kind.h | 4 +- include/tvm/tir/expr.h | 57 ++ include/tvm/tir/function.h | 2 +- include/tvm/tir/schedule/schedule.h | 5 +- python/tvm/_ffi/_ctypes/object.py | 22 + python/tvm/_ffi/_ctypes/packed_func.py | 7 +- python/tvm/_ffi/_ctypes/types.py | 3 + python/tvm/_ffi/_cython/base.pxi | 5 +- python/tvm/_ffi/_cython/object.pxi | 10 + python/tvm/_ffi/_cython/packed_func.pxi | 9 +- python/tvm/_ffi/runtime_ctypes.py | 3 +- python/tvm/driver/tvmc/registry.py | 22 +- python/tvm/ir/attrs.py | 2 +- python/tvm/ir/expr.py | 5 +- python/tvm/meta_schedule/tune_context.py | 3 +- python/tvm/relax/op/statistical.py | 22 +- python/tvm/relax/testing/ast_printer.py | 18 +- python/tvm/relax/training/setup_trainer.py | 4 +- python/tvm/relax/utils.py | 3 + .../relay/backend/contrib/ethosu/legalize.py | 2 +- python/tvm/relay/op/_tensor_grad.py | 3 + python/tvm/relay/op/_transform.py | 8 +- python/tvm/relay/op/contrib/ethosu.py | 4 +- python/tvm/relay/op/transform.py | 25 +- .../transform/fake_quantization_to_integer.py | 5 +- python/tvm/runtime/__init__.py | 4 +- python/tvm/runtime/container.py | 38 + python/tvm/runtime/object_generic.py | 75 +- python/tvm/script/parser/tir/parser.py | 2 + python/tvm/te/hybrid/calls.py | 12 +- python/tvm/te/hybrid/parser.py | 4 +- python/tvm/te/hybrid/utils.py | 28 +- python/tvm/te/operation.py | 1 - python/tvm/te/tensor.py | 11 +- python/tvm/tir/__init__.py | 1 + python/tvm/tir/expr.py | 4 + python/tvm/tir/ir_builder.py | 6 +- python/tvm/tir/op.py | 151 ++-- python/tvm/tir/schedule/trace.py | 15 +- python/tvm/topi/arm_cpu/conv2d_gemm.py | 2 +- python/tvm/topi/cuda/batch_matmul.py | 8 +- rust/tvm-rt/src/module.rs | 5 +- rust/tvm-sys/src/packed_func.rs | 35 +- src/auto_scheduler/compute_dag.cc | 16 +- .../search_policy/sketch_policy_rules.cc | 3 +- src/auto_scheduler/search_policy/utils.h | 12 +- .../msc/core/printer/msc_base_printer.cc | 9 + .../msc/core/printer/prototxt_printer.cc | 4 + src/contrib/msc/core/utils.cc | 4 + src/driver/driver_api.cc | 5 +- src/ir/attrs.cc | 89 +++ src/ir/expr.cc | 17 +- src/ir/transform.cc | 41 +- src/meta_schedule/database/database_utils.cc | 10 +- src/meta_schedule/database/json_database.cc | 4 +- .../mutator/mutate_thread_binding.cc | 2 +- src/meta_schedule/mutator/mutate_tile_size.cc | 6 +- src/meta_schedule/mutator/mutate_unroll.cc | 6 +- .../schedule/cuda/thread_bind.cc | 6 +- .../schedule_rule/cross_thread_reduction.cc | 8 +- .../schedule_rule/multi_level_tiling.cc | 5 +- .../parallel_vectorize_unroll.cc | 6 +- .../schedule_rule/schedule_rule.cc | 12 +- src/meta_schedule/utils.h | 38 +- src/node/boxed_primitive.cc | 134 ++++ src/node/script_printer.cc | 16 +- src/node/structural_equal.cc | 37 +- src/relax/backend/vm/codegen_vm.cc | 2 + src/relax/backend/vm/codegen_vm_tir.cc | 30 +- src/relax/op/tensor/create.cc | 2 +- src/relax/op/tensor/create.h | 2 +- src/relax/op/tensor/manipulate.cc | 6 +- src/relax/op/tensor/manipulate.h | 4 +- .../backend/contrib/cmsisnn/compiler_attrs.cc | 2 +- src/relay/backend/contrib/cmsisnn/target.cc | 2 +- src/relay/backend/contrib/cutlass/target.cc | 18 +- .../backend/contrib/ethosn/ethosn_api.cc | 6 +- src/relay/backend/contrib/ethosu/codegen.cc | 3 +- .../backend/contrib/ethosu/preprocess.cc | 4 +- .../contrib/example_target_hooks/target.cc | 2 +- src/relay/backend/contrib/tensorrt/codegen.cc | 4 +- src/relay/backend/contrib/tensorrt/target.cc | 14 +- src/relay/backend/contrib/uma/targets.cc | 7 +- src/relay/backend/executor.cc | 10 +- src/relay/backend/runtime.cc | 4 +- src/relay/ir/dataflow_matcher.cc | 36 + src/relay/op/make_op.h | 2 +- src/relay/op/tensor/transform.cc | 48 +- .../transforms/combine_parallel_op_batch.cc | 2 +- src/relay/transforms/fold_constant.cc | 2 +- src/relay/transforms/higher_order_gradient.cc | 2 - src/relay/transforms/to_mixed_precision.cc | 4 +- src/runtime/boxed_primitive.cc | 65 ++ src/runtime/crt/common/crt_runtime_api.c | 8 +- src/runtime/disco/bcast_session.cc | 8 +- src/runtime/minrpc/rpc_reference.h | 8 + src/runtime/relax_vm/builtin.cc | 10 +- .../printer/doc_printer/python_doc_printer.cc | 23 +- src/script/printer/ir/misc.cc | 15 + src/script/printer/relax/tir.cc | 6 +- src/support/array.h | 52 +- src/support/ffi_testing.cc | 52 ++ src/target/llvm/codegen_cpu.cc | 29 +- src/target/llvm/llvm_instance.cc | 14 +- src/target/tag.cc | 66 +- src/target/target.cc | 66 +- src/target/target_kind.cc | 137 ++-- src/te/operation/compute_op.cc | 26 +- src/te/operation/create_primfunc.cc | 15 +- src/te/operation/placeholder_op.cc | 12 +- src/te/schedule/schedule_dataflow_rewrite.cc | 7 +- .../analysis/calculate_allocated_memory.cc | 2 +- src/tir/ir/expr.cc | 20 +- src/tir/ir/function.cc | 7 + src/tir/ir/specialize.cc | 2 +- src/tir/ir/stmt.cc | 32 +- src/tir/ir/utils.cc | 68 ++ src/tir/ir/utils.h | 51 ++ src/tir/op/op.cc | 16 +- src/tir/schedule/concrete_schedule.cc | 14 +- src/tir/schedule/concrete_schedule.h | 5 +- src/tir/schedule/instruction_traits.h | 5 + src/tir/schedule/primitive.h | 5 +- src/tir/schedule/primitive/annotate.cc | 3 + src/tir/schedule/primitive/sampling.cc | 36 +- src/tir/schedule/trace.cc | 12 +- src/tir/schedule/traced_schedule.cc | 6 +- src/tir/schedule/traced_schedule.h | 5 +- .../transforms/inline_private_functions.cc | 2 +- src/tir/transforms/ir_utils.h | 1 + src/tir/transforms/lower_tvm_builtin.cc | 2 + src/tir/transforms/make_packed_api.cc | 45 +- tests/cpp/relay/backend/runtime_test.cc | 10 +- tests/cpp/target_test.cc | 56 +- .../test_runtime_packed_func.py | 18 +- .../arith/test_arith_canonical_simplify.py | 23 +- .../arith/test_arith_iter_affine_map.py | 35 +- .../test_arith_narrow_predicate_expression.py | 21 +- .../arith/test_arith_rewrite_simplify.py | 63 +- .../test_arith_solve_linear_equations.py | 15 +- .../test_arith_solve_linear_inequality.py | 11 +- .../codegen/test_target_codegen_cuda.py | 2 +- .../codegen/test_target_codegen_llvm.py | 41 ++ .../ir/test_container_structural_equal.py | 30 +- tests/python/ir/test_ir_container.py | 15 +- tests/python/ir/test_ir_type.py | 9 +- .../test_distributed_tvmscript_printer.py | 4 +- tests/python/relax/test_ast_printer.py | 2 +- .../relax/test_backend_dispatch_sort_scan.py | 10 +- .../relax/test_tvmscript_printer_relax.py | 6 +- tests/python/relax/test_vm_build.py | 2 +- tests/python/relax/test_vm_codegen_tir.py | 5 +- tests/python/relay/test_dataflow_pattern.py | 3 +- tests/python/relay/test_executor.py | 2 +- tests/python/relay/test_runtime.py | 4 +- tests/python/relay/test_type_infer.py | 65 +- .../python/runtime/test_runtime_container.py | 130 +++- tests/python/te/test_te_schedule_tensorize.py | 20 +- tests/python/te/test_te_tag.py | 10 +- tests/python/tir-base/test_lower_build.py | 2 +- tests/python/tir-base/test_tir_buffer.py | 17 +- tests/python/tir-base/test_tir_index_map.py | 55 +- tests/python/tir-base/test_tir_nodes.py | 27 +- .../test_tir_schedule_sampling.py | 2 +- .../tir-schedule/test_tir_schedule_state.py | 4 +- ...est_tir_transform_compact_buffer_region.py | 71 +- ...tir_transform_instrument_bound_checkers.py | 8 +- .../test_tir_transform_make_packed_api.py | 139 ++++ .../test_tir_transform_storage_rewrite.py | 4 +- .../tvmscript/test_tvmscript_error_report.py | 17 +- .../tvmscript/test_tvmscript_printer_tir.py | 12 +- .../tvmscript/test_tvmscript_roundtrip.py | 31 +- vta/python/vta/transform.py | 13 +- 184 files changed, 3215 insertions(+), 1221 deletions(-) create mode 100644 include/tvm/runtime/container/boxed_primitive.h create mode 100644 src/node/boxed_primitive.cc create mode 100644 src/runtime/boxed_primitive.cc create mode 100644 src/tir/ir/utils.cc create mode 100644 src/tir/ir/utils.h diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 81611b1a535a..d038d5f59a5f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -265,7 +265,16 @@ class DictAttrs : public Attrs { auto it = node->dict.find(attr_key); if (it != node->dict.end()) { - return Downcast>((*it).second); + // For backwards compatibility, return through TVMRetValue. + // This triggers any automatic conversions registered with + // PackedFuncValueConverter. Importantly, this allows use of + // `GetAttr` and `GetAttr` for properties that + // are stored internally as `runtime::Box` and + // `runtime::Box`. + TVMRetValue ret; + ret = (*it).second; + Optional obj = ret; + return obj; } else { return default_value; } @@ -315,6 +324,46 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } +/*! + * \brief Copy the DictAttrs, but overrides attributes with the + * entries from \p attrs. + * + * \param attrs The DictAttrs to update + * + * \param new_attrs Key/values attributes to add to \p attrs. + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); + +/*! + * \brief Copy the DictAttrs, but overrides a single attribute. + * + * \param attrs The DictAttrs to update + * + * \param key The update to insert or update. + * + * \param value The new value of the attribute + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value); + +inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) { + return WithAttr(std::move(attrs), String(key), std::move(value)); +} + +/*! + * \brief Copy the DictAttrs, but without a specific attribute. + * + * \param attrs The DictAttrs to update + * + * \param key The key to remove + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key); + /*! * \brief Copy the function or module, but overrides * the attribute value key with the value. @@ -347,12 +396,8 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } + node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value); + return input; } @@ -371,13 +416,9 @@ inline TFunc WithAttrs(TFunc input, Map attrs) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - if (node->attrs.defined()) { - for (const auto& pair : attrs) { - node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); - } - } else { - node->attrs = DictAttrs(std::move(attrs)); - } + + node->attrs = WithAttrs(std::move(node->attrs), attrs); + return input; } @@ -412,10 +453,9 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - if (input->attrs.defined()) { - TNode* node = input.CopyOnWrite(); - node->attrs.CopyOnWrite()->dict.erase(attr_key); - } + TNode* node = input.CopyOnWrite(); + node->attrs = WithoutAttr(std::move(node->attrs), attr_key); + return input; } diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 9b522389227a..efde52385177 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -770,53 +770,121 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -// common rule for RetValue and ArgValue + +// Automatic conversion into IntImm, Integer, and Bool, when called +// through the FFI. Automatic conversions into PrimExpr are +// registered in "tvm/tir/expr.h", as it includes conversions to the +// TIR-only StringImm. +// +// While the FFI only requires the From() method, these +// implementations also define a TryFrom() method to avoid duplicate +// logic in the PrimExpr conversion. + template <> -struct PackedFuncValueConverter { - static PrimExpr From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return PrimExpr(ObjectPtr(nullptr)); - } - if (val.type_code() == kDLInt) { - int64_t value = val.operator int64_t(); - if (value > std::numeric_limits::max() || value < std::numeric_limits::min()) { - return IntImm(runtime::DataType::Int(64), value); - } - return IntImm(runtime::DataType::Int(32), val.operator int()); - } - if (val.type_code() == kDLFloat) { - return FloatImm(runtime::DataType::Float(32), val.operator double()); +struct PackedFuncValueConverter { + template + static Optional TryFrom(const PODSubclass& val) { + if (auto opt = val.TryAsInt()) { + int64_t value = opt.value(); + auto dtype = + (value > std::numeric_limits::max() || value < std::numeric_limits::min()) + ? DataType::Int(64) + : DataType::Int(32); + return IntImm(dtype, value); + } else if (auto opt = val.TryAsBool()) { + return IntImm(DataType::Int(32), opt.value()); + } else { + return NullOpt; } + } - return PrimExpr::FromObject_(val.AsObjectRef()); + template + static tvm::IntImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } } }; template <> struct PackedFuncValueConverter { - static tvm::Integer From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Integer(ObjectPtr(nullptr)); + template + static tvm::Integer From(const PODSubclass& val) { + if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return Integer(opt.value()); + } else { + return val.template AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - return Integer(val.operator int()); - } - return val.AsObjectRef(); } }; template <> struct PackedFuncValueConverter { - static tvm::Bool From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Bool(ObjectPtr(nullptr)); + template + static Optional TryFrom(const PODSubclass& val) { + if (auto opt = val.TryAsBool()) { + return tvm::Bool(opt.value()); + } else if (auto opt = val.TryAsInt()) { + int value = opt.value(); + ICHECK(value == 0 || value == 1) + << "ValueError: boolean value can only be 0 or 1, but get " << value; + return tvm::Bool(static_cast(value)); + } else { + return NullOpt; + } + } + + template + static tvm::Bool From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - int v = val.operator int(); - ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; - return Bool(static_cast(v)); + } +}; + +template <> +struct PackedFuncValueConverter { + static Optional TryFrom(const TVMPODValue_& val) { + if (auto opt = val.TryAsFloat()) { + return FloatImm(runtime::DataType::Float(32), opt.value()); + } else { + return NullOpt; + } + } + + template + static tvm::FloatImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } + } +}; + +/* \brief Backwards compatibility wrapper for IntImm arguments + * + * In previous versions of TVM, IntImm was the default FFI type for + * integer arguments, instead of runtime::Int. For backwards + * compatibility where the callee has been updated to expected a + * runtime::Int, the caller has not been updated to provide a + * runtime::Int (e.g. relay script parsing), and the auto-unboxing of + * runtime::Int does not apply (e.g. making an `Array`), + * allow the IntImm to be generated. + */ +template <> +struct PackedFuncValueConverter { + template + static runtime::Int From(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return runtime::Int(val.template AsObjectRef()->value); + } else { + return val.template AsObjectRef(); } - return val.AsObjectRef(); } }; diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index adf332525020..5828d98206ad 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -271,7 +271,36 @@ class PassContext : public ObjectRef { using ValueNodeType = typename ValueType::ContainerType; // NOTE: we could further update the function later. uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); - RegisterConfigOption(key, tindex); + auto type_key = runtime::Object::TypeIndex2Key(tindex); + + auto* reflection = ReflectionVTable::Global(); + + auto legalization = [=](ObjectRef obj) -> ObjectRef { + if (obj->IsInstance::ContainerType>()) { + return reflection->CreateObject(type_key, Downcast>(obj)); + } else { + // Backwards compatibility for config options defined prior to + // https://github.com/apache/tvm/pull/16183. This commit + // changed the default FFI conversion of python integers from + // `tvm::IntImm` to `runtime::Int`. + // + // This backwards compatibility fix can be removed when all + // options registered with TVM_REGISTER_PASS_CONFIG_OPTION are + // updated to use `runtime::Int` and `runtime::Bool`. + TVMRetValue ret; + ret = obj; + try { + ValueType legalized = ret; + return legalized; + } catch (Error& err) { + LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key + << ", but received error when converting to this type.\n" + << err.what(); + } + } + }; + + RegisterConfigOption(key, tindex, legalization); return tindex; } @@ -285,7 +314,8 @@ class PassContext : public ObjectRef { // The exit of a pass context scope. TVM_DLL void ExitWithScope(); // Register configuration key value type. - TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index); + TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index, + std::function legalization); // Classes to get the Python `with` like syntax. friend class Internal; diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index d91812fb55cb..90aec05187eb 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef { * \param thread_extents Candidates of thread axis extent (values are required to be positive). * \return The schedule rule created */ - TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); + TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The schedule rule created @@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef { * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma. * \return The schedule rule created */ - TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // - int max_vectorize_extent, // - Array unroll_max_steps, // + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // + Array unroll_max_steps, // bool unroll_explicit); /*! * \brief Auto bind loops around the block to BlockIdx and ThreadIdx diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 249b9cd0e50d..91020fc7443b 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode { }; // struct SqueezeAttrs struct SplitAttrs : public tvm::AttrsNode { - ObjectRef indices_or_sections; + Variant> indices_or_sections; int axis; TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index f1046ef24266..b4c653a0a59e 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -81,6 +81,7 @@ #ifdef __cplusplus extern "C" { #endif +#include #include #include @@ -186,11 +187,12 @@ typedef enum { kTVMBytes = 12U, kTVMNDArrayHandle = 13U, kTVMObjectRValueRefArg = 14U, + kTVMArgBool = 15U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. // Open an issue at the repo if you need a section of code. - kTVMExtBegin = 15U, + kTVMExtBegin = 16U, kTVMNNVMFirst = 16U, kTVMNNVMLast = 20U, // The following section of code is used for non-reserved types. @@ -207,6 +209,7 @@ typedef DLTensor* TVMArrayHandle; */ typedef union { int64_t v_int64; + bool v_bool; double v_float64; void* v_handle; const char* v_str; diff --git a/include/tvm/runtime/container/boxed_primitive.h b/include/tvm/runtime/container/boxed_primitive.h new file mode 100644 index 000000000000..8d01b5dc17b5 --- /dev/null +++ b/include/tvm/runtime/container/boxed_primitive.h @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/boxed_primitive.h + * \brief Runtime container types for primitives stored as ObjectRef. + */ +#ifndef TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ +#define TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +namespace detail { +/* \brief Provide the BoxNode type traits in templated contexts + * + * The Box class is used in many templated contexts, and is easier + * to have templated over the primitive type. + * + * However, much of the TVM type system depends on classes having a + * unique name. For example, the use of `Object::IsInstance` depends + * on `Object::GetOrAllocRuntimeTypeIndex`. Any duplicate names will + * result in duplicate indices, and invalid downcasting. Furthermore, + * the name must be specified in the Python FFI using + * `tvm._ffi.register_object`. This prevents use of + * `typeid(T)::name()` to build a unique name, as the name is not + * required to be human-readable or consistent across compilers. + * + * This utility struct should be specialized over the primitive type + * held by the box, to allow explicit listing of the `_type_key` and + * other similar tratis. + * + * Note: This should only contain traits that are required at runtime, + * and should *not* contain extensions for features that are only + * available at compile-time. For integration with compile-time-only + * functionality (e.g. StructuralHash, StructuralEqual), see + * `BoxNodeCompileTimeTraits` in `src/node/boxed_primitive.cc`. + */ +template +struct BoxNodeRuntimeTraits; + +} // namespace detail + +template +class BoxNode : public Object { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + explicit BoxNode(Prim value) : value(value) {} + + /*! \brief The boxed value */ + Prim value; + + static constexpr const char* _type_key = detail::BoxNodeRuntimeTraits::_type_key; + static constexpr bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(BoxNode, Object); +}; + +template +class Box : public ObjectRef { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + Box(Prim value) : ObjectRef(make_object>(value)) {} // NOLINT(*) + + operator Prim() const { return (*this)->value; } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Box, ObjectRef, BoxNode); +}; + +/*! \brief Boxed version of C++ int64_t + * + * Can be used to store POD integer values as a TVM ObjectRef. Used + * for FFI handling, and for storing POD types inside TVM containers. + */ +using Int = Box; + +/*! \brief Boxed version of C++ double + * + * Can be used to store POD floating-point values as a TVM ObjectRef. + * Used for FFI handling, and for storing POD types inside TVM + * containers. + */ +using Float = Box; + +/*! \brief Boxed version of C++ bool + * + * Can be used to store POD boolean values as a TVM ObjectRef. Used + * for FFI handling, and for storing POD types inside TVM containers. + * + * When passing from Python to C++, TVM PackedFunc conversion follow + * C++ conversion rules, and allow bool->int and int->bool + * conversions. When passing from C++ to Python, the types are + * returned as bool or int. If the C++ function uses ObjectRef to + * hold the object, a Python to C++ to Python round trip will preserve + * the distinction between bool and int. + */ +using Bool = Box; + +namespace detail { +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxInt"; +}; + +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxFloat"; +}; + +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxBool"; +}; +} // namespace detail + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h index 7953ac47c1cf..e8defa4e6fee 100644 --- a/include/tvm/runtime/container/variant.h +++ b/include/tvm/runtime/container/variant.h @@ -82,7 +82,7 @@ class Variant : public ObjectRef { public: /* \brief Helper utility to check if the type is part of the variant */ template - static constexpr bool is_variant = (std::is_same_v || ...); + static constexpr bool is_variant = (std::is_base_of_v || ...); /* \brief Helper utility for SFINAE if the type is part of the variant */ template diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 3eb225fccffe..fef61a753103 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -226,6 +226,8 @@ class NDArray : public ObjectRef { protected: friend class TVMPODValue_; + template + friend class TVMPODValue_CRTP_; friend class TVMRetValue; friend class TVMArgsSetter; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..98196c13af7f 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -37,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -429,9 +431,11 @@ inline const char* ArgTypeCode2Str(int type_code); inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) +#define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) \ + "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) + // macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) \ - ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) +#define TVM_CHECK_TYPE_CODE(CODE, T) ICHECK_EQ(CODE, T) << TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -510,6 +514,7 @@ struct ObjectTypeChecker> { } static std::string TypeName() { return "Array[" + ObjectTypeChecker::TypeName() + "]"; } }; + template struct ObjectTypeChecker> { static Optional CheckAndGetMismatch(const Object* ptr) { @@ -545,40 +550,43 @@ struct ObjectTypeChecker> { } }; +template +struct ObjectTypeChecker> { + static Optional CheckAndGetMismatch(const Object* ptr) { + return ObjectTypeChecker::CheckAndGetMismatch(ptr); + } + static bool Check(const Object* ptr) { return ObjectTypeChecker::Check(ptr); } + static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } + static std::string VariantNames() { return ObjectTypeChecker::TypeName(); } +}; + +template +struct ObjectTypeChecker> { + static Optional CheckAndGetMismatch(const Object* ptr) { + auto try_first = ObjectTypeChecker::CheckAndGetMismatch(ptr); + if (!try_first.defined()) { + return try_first; + } + + return ObjectTypeChecker>::CheckAndGetMismatch(ptr); + } + static bool Check(const Object* ptr) { + return ObjectTypeChecker::Check(ptr) || + ObjectTypeChecker>::Check(ptr); + } + static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } + static std::string VariantNames() { + return ObjectTypeChecker::TypeName() + ", " + + ObjectTypeChecker>::VariantNames(); + } +}; + /*! * \brief Internal base class to * handle conversion to POD values. */ class TVMPODValue_ { public: - operator double() const { - // Allow automatic conversion from int to float - // This avoids errors when user pass in int from - // the frontend while the API expects a float. - if (type_code_ == kDLInt) { - return static_cast(value_.v_int64); - } - TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); - return value_.v_float64; - } - operator int64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; - } - operator uint64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; - } - operator int() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - ICHECK_LE(value_.v_int64, std::numeric_limits::max()); - ICHECK_GE(value_.v_int64, std::numeric_limits::min()); - return static_cast(value_.v_int64); - } - operator bool() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64 != 0; - } operator void*() const { if (type_code_ == kTVMNullptr) return nullptr; if (type_code_ == kTVMDLTensorHandle) return value_.v_handle; @@ -628,12 +636,39 @@ class TVMPODValue_ { T* ptr() const { return static_cast(value_.v_handle); } - // ObjectRef handling - template ::value>::type> - inline bool IsObjectRef() const; - template - inline TObjectRef AsObjectRef() const; + + std::optional TryAsBool() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kTVMArgBool) { + return value_.v_bool; + } else { + return std::nullopt; + } + } + + std::optional TryAsInt() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kDLInt) { + return value_.v_int64; + } else { + return std::nullopt; + } + } + + std::optional TryAsFloat() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kDLFloat) { + return value_.v_float64; + } else { + return std::nullopt; + } + } protected: friend class TVMArgsSetter; @@ -648,13 +683,90 @@ class TVMPODValue_ { int type_code_; }; +/*! \brief A utility class that adds methods useful for each POD type + * + * These cannot be provided in the base PODValue_ class, because + * TVMArgValue and TVMRetValue have different semantics for kTVMStr + * and kTVMBytes. + * + * kTVMStr: + * + * For `TVMArgValue`, the active variant is `v_str`, a `const + * char*`. For `TVMRetValue`, the active variant is `v_handle`, + * and should be cast from `void*` to `std::string*`. + * + * kTVMBytes: + * + * The active variant is `v_handle`, a `void*`. For + * `TVMArgValue`, should be cast to `TVMByteArray*`. For + * `TVMRetValue`, should be cast to `std::string*`. + * + * When converting into an `ObjectRef`, a string may be used to build + * a `tvm::runtime::String`. Because TVMArgValue and TVMRetValue use + * different representations for strings, any utility funciton which + * might attempt a conversion to an `ObjectRef` must be performed + * within a context that is aware of the derived class. + */ +template +class TVMPODValue_CRTP_ : public TVMPODValue_ { + public: + using TVMPODValue_::TVMPODValue_; + + // ObjectRef handling + template ::value>::type> + inline bool IsObjectRef() const; + template + inline TObjectRef AsObjectRef() const; + + operator double() const { + // Allow automatic conversion from int to float + // This avoids errors when user pass in int from + // the frontend while the API expects a float. + if (auto opt = TryAsFloat()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsBool()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); + } + } + operator int64_t() const { + if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsBool()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } + } + operator uint64_t() const { return operator int64_t(); } + operator int() const { + int64_t value = operator int64_t(); + ICHECK_LE(value, std::numeric_limits::max()); + ICHECK_GE(value, std::numeric_limits::min()); + return value; + } + operator bool() const { + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } + } +}; + /*! * \brief A single argument value to PackedFunc. * Containing both type_code and TVMValue * * Provides utilities to do type cast into other types. */ -class TVMArgValue : public TVMPODValue_ { +class TVMArgValue : public TVMPODValue_CRTP_ { public: /*! \brief default constructor */ TVMArgValue() {} @@ -663,21 +775,21 @@ class TVMArgValue : public TVMPODValue_ { * \param value of the function * \param type_code The type code. */ - TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} + TVMArgValue(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_::AsObjectRef; - using TVMPODValue_::IsObjectRef; + using TVMPODValue_CRTP_::AsObjectRef; + using TVMPODValue_CRTP_::IsObjectRef; // conversion operator. operator std::string() const { @@ -714,15 +826,15 @@ class TVMArgValue : public TVMPODValue_ { * * \note For internal development purpose only. */ -class TVMMovableArgValue_ : public TVMPODValue_ { +class TVMMovableArgValue_ : public TVMPODValue_CRTP_ { public: - TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} + TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; @@ -804,7 +916,7 @@ class TVMMovableArgValueWithContext_ { * TVMRetValue holds value and will manage the underlying containers * when it stores a complicated data type. */ -class TVMRetValue : public TVMPODValue_ { +class TVMRetValue : public TVMPODValue_CRTP_ { public: /*! \brief default constructor */ TVMRetValue() {} @@ -812,28 +924,28 @@ class TVMRetValue : public TVMPODValue_ { * \brief move constructor from another return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { + TVMRetValue(TVMRetValue&& other) : TVMPODValue_CRTP_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kTVMNullptr; } /*! \brief destructor */ ~TVMRetValue() { this->Clear(); } // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_::AsObjectRef; - using TVMPODValue_::IsObjectRef; + using TVMPODValue_CRTP_::AsObjectRef; + using TVMPODValue_CRTP_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } + TVMRetValue(const TVMRetValue& other) : TVMPODValue_CRTP_() { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMDataType) { @@ -901,8 +1013,8 @@ class TVMRetValue : public TVMPODValue_ { } TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { - this->SwitchToPOD(kDLInt); - value_.v_int64 = value; + this->SwitchToPOD(kTVMArgBool); + value_.v_bool = value; return *this; } TVMRetValue& operator=(std::string value) { @@ -974,7 +1086,8 @@ class TVMRetValue : public TVMPODValue_ { */ static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); + ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle || + type_code == kTVMArgBool); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; @@ -989,9 +1102,9 @@ class TVMRetValue : public TVMPODValue_ { } // ObjectRef handling template ::value>::type> + typename = typename std::enable_if_t>> inline TVMRetValue& operator=(TObjectRef other); - template ::value>::type> + template >> inline operator T() const; private: @@ -1019,9 +1132,11 @@ class TVMRetValue : public TVMPODValue_ { break; } case kTVMObjectHandle: { - // Avoid operator ObjectRef as we already know it is not NDArray/Module - SwitchToObject(kTVMObjectHandle, - GetObjectPtr(static_cast(other.value_.v_handle))); + // We already known it is not NDArray/Module, but + // operator=(ObjectRef) also handles conversions from wrappers + // around primitive types. For NDArray/Module, the duplicate + // checks are removed with if constexpr. + operator=(other.operator ObjectRef()); break; } case kTVMObjectRValueRefArg: { @@ -1265,6 +1380,8 @@ inline const char* ArgTypeCode2Str(int type_code) { switch (type_code) { case kDLInt: return "int"; + case kTVMArgBool: + return "bool"; case kDLUInt: return "uint"; case kDLFloat: @@ -1686,6 +1803,10 @@ class TVMArgsSetter { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } + TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const { + values_[i].v_bool = value; + type_codes_[i] = kTVMArgBool; + } TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); ICHECK_LE(value, static_cast(std::numeric_limits::max())); @@ -1951,38 +2072,110 @@ inline T TVMArgs::At(int i) const { template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { using ContainerType = typename std::remove_reference::type::ContainerType; - if (value.defined()) { - Object* ptr = value.data_.data_; - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + if (!value.defined()) { + type_codes_[i] = kTVMNullptr; + values_[i].v_handle = nullptr; + return; + } + + Object* ptr = value.data_.data_; + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMPackedFuncHandle; - } else if (std::is_rvalue_reference::value) { - values_[i].v_handle = const_cast(&(value.data_.data_)); - type_codes_[i] = kTVMObjectRValueRefArg; - } else { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kTVMObjectHandle; + return; + } + } + + // Like with BoxInt, unwrap any BoxBool instances. See the BoxInt + // explanation for more detail. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_bool = static_cast(ptr)->value; + type_codes_[i] = kTVMArgBool; + return; + } + } + + // If a boxed integer is being returned, always unbox it to the + // primitive type. This must be checked at the PackedFunc level to + // ensure that a boxed primitive argument is round-tripped correctly + // when the boxing is no longer required. + // + // For example, consider a PackedFunc with signature `ObjectRef + // func(Array)`, and returns the first element of that + // array. When passing a Python array `[5, 17.5, "hello"]`, the + // items are converted to `[Box(5), Box(17.5), + // String("hello")]` in order to provide an `Array`. + // + // If we had no additional conversions, the caller would receive the + // return value as a `Box(5)`, which would be unexpected and + // require additional unwrapping. We could perform this check + // inside the PackedFunc, but that would require a large amount of + // duplicated checked, and would require explicit handling of + // `TVMRetValue`. Instead, this conversion is checked in the FFI + // return value, to ensure that boxing/unboxing is applied + // consistently. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_int64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgInt; + return; + } + } + + // Like with BoxInt, unwrap any BoxFloat instances. See the BoxInt + // explanation for more detail. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_float64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgFloat; + return; } + } + + // Final fallback, if the ObjectRef has no special cases that must + // be expressed within the TVMRetValue. + if constexpr (std::is_rvalue_reference_v) { + values_[i].v_handle = const_cast(&(value.data_.data_)); + type_codes_[i] = kTVMObjectRValueRefArg; } else { - type_codes_[i] = kTVMNullptr; - values_[i].v_handle = nullptr; + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kTVMObjectHandle; } } +template template -inline bool TVMPODValue_::IsObjectRef() const { +inline bool TVMPODValue_CRTP_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { @@ -2012,8 +2205,9 @@ inline bool TVMPODValue_::IsObjectRef() const { ObjectTypeChecker::Check(static_cast(value_.v_handle))); } +template template -inline TObjectRef TVMPODValue_::AsObjectRef() const { +inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { static_assert(std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; @@ -2023,8 +2217,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } - // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + + // NOTE: The following code uses "if constexpr" wherever possible to + // minimize the number of runtime checks. + if constexpr (std::is_base_of_v) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); ObjectPtr data = @@ -2033,7 +2229,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2041,7 +2238,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of PackedFunc TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2049,6 +2247,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } + if (type_code_ == kTVMObjectHandle) { // normal object type check. Object* ptr = static_cast(value_.v_handle); @@ -2062,51 +2261,152 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker::TypeName() << ", but got " << checked_type.value(); return TObjectRef(GetObjectPtr(ptr)); - } else if (std::is_base_of::value && - type_code_ == kTVMNDArrayHandle) { - // Casting to a base class that NDArray can sub-class - ObjectPtr data = - NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); - return TObjectRef(data); - } else if (std::is_base_of::value && - type_code_ == kTVMModuleHandle) { - // Casting to a base class that Module can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else if (std::is_base_of::value && - type_code_ == kTVMPackedFuncHandle) { - // Casting to a base class that PackedFunc can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else { - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - return TObjectRef(ObjectPtr(nullptr)); } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMNDArrayHandle) { + // Casting to a base class that NDArray can sub-class + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); + return TObjectRef(data); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMModuleHandle) { + // Casting to a base class that Module can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMPackedFuncHandle) { + // Casting to a base class that PackedFunc can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgInt) { + return Int(value_.v_int64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgFloat) { + return Float(value_.v_float64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgBool) { + return Bool(value_.v_bool); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMStr || type_code_ == kTVMBytes) { + // This step is the reason why `AsObjectRef` cannot be provided + // in the base `TVMPODValue_` class. Because `TVMArgValue` and + // `TVMRetValue` have different implementations of `operator + // std::string`, with different interpretations of `kTVMStr` and + // `kTVMBytes`, we must delegate to those implementations. + // + // This could be done with a pure virtual method in + // `TVMPODValue_`, but that would require a vtable lookup during + // FFI conversions, imposing a runtime overhead. + return String(static_cast(this)->operator std::string()); + } + } + + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); } template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); - if (ptr != nullptr) { - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(NDArray(std::move(other.data_))); - } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(Module(std::move(other.data_))); - } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(PackedFunc(std::move(other.data_))); + + if (ptr) { + // Check for special cases of ObjectRef that have explicit + // representation within the TVMRetValue structure. + // (e.g. Unboxing of `runtime::Int` into a primitive integer + // with type code kTVMArgInt.) The checks below are written to + // handle three distinct cases. + // + // 1. If TObjectRef is a subclass of TSpecialCase, the special + // case applies, and can be handled without a runtime check. + // No runtime checks should be performed. + // + // 2. If TSpecialCase is a subclass of TObjectRef, the special + // case might apply, and requires a runtime check. + // + // 3. If neither TObjectRef nor TSpecialCase is a subclass of + // the other, then the special case does not apply. No + // runtime checks should be performed. + // + // Use of `if constexpr` ensures that the C++ subclass checks + // are applied when compiling TVM, and runtime overhead are only + // present when they may be applicable. + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(NDArray(std::move(other.data_))); + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(Module(std::move(other.data_))); + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(PackedFunc(std::move(other.data_))); + } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + bool value = static_cast(ptr)->value; + return operator=(value); + } } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + int64_t value = static_cast(ptr)->value; + return operator=(value); + } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + double value = static_cast(ptr)->value; + return operator=(value); + } + } + + // If the object being stored is not one of the special cases, + // it is stored as an ObjectRef. SwitchToObject(kTVMObjectHandle, std::move(other.data_)); + } else { + // No object is present, set to an explicitly null handle. When + // returning to a Python callee, this will be converted to + // `None`. SwitchToPOD(kTVMNullptr); value_.v_handle = nullptr; } + return *this; } @@ -2139,20 +2439,123 @@ inline PackedFunc Module::GetFunction(const String& name, bool query_imports) { // specializations of PackedFuncValueConverter template <> struct PackedFuncValueConverter<::tvm::runtime::String> { - static String From(const TVMArgValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); + template + static String From(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return val.template AsObjectRef(); } else { return tvm::runtime::String(val.operator std::string()); } } +}; - static String From(const TVMRetValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); +template +struct PackedFuncValueConverter> { + static Array From(const TVMArgValue& val) { + auto untyped_array = val.AsObjectRef>(); + + // Attempt to convert each item of the array into the desired + // type. If the items do not require a conversion, no copies are + // made. + return untyped_array.Map([](ObjectRef item) { + // Recursively apply any conversions that have been registered + // with TVM's FFI. + // + // For example, a function that accepts `Array` may + // be called from python with argument `[1,2]`. By the time + // `PackedFuncValueConverter::From` is called, the python list + // has been converted to `Array`, with contents + // converted into `runtime::Int`. Converting the `ObjectRef` + // to `TVMArgValue` unboxes the `runtime::Int` back into a + // primitive with type code `kTVMArgInt`. This primitive can + // then be converted to a PrimExpr using + // `PackedFuncValueConverter::From`. + // + // The use of two conversions, first from python `int` to + // `runtime::Int` and then from `runtime::Int` to `PrimExpr`, + // is a result of the split between `libtvm_runtime.so` and + // `libtvm.so`. The FFI must function correctly in both + // cases, and so conversions applied by default in the Python + // FFI implementation may only produce types that are + // available in both libraries. In the C++ FFI implementation + // (i.e. this file), libtvm.so may apply additional + // conversions that are not present in libtvm_runtime.so. + TVMValue value; + int type_code; + TVMArgsSetter setter(&value, &type_code); + setter(0, item); + TVMArgValue arg(value, type_code); + return PackedFuncValueConverter::From(arg); + }); + } + static Array From(const TVMRetValue& val) { + auto untyped_array = val.AsObjectRef>(); + + return untyped_array.Map([](ObjectRef item) { + TVMRetValue item_val; + item_val = std::move(item); + return PackedFuncValueConverter::From(item_val); + }); + } +}; + +template +struct PackedFuncValueConverter> { + static Map From(const TVMArgValue& val) { + auto untyped_map = val.AsObjectRef>(); + + if (ObjectTypeChecker>::Check(untyped_map.get())) { + // Early bail-out for common case where no type conversions are + // required. + return Downcast>(untyped_map); + } + + Map output; + for (const auto& kv : untyped_map) { + T new_key = [&]() { + TVMValue pod_value; + int type_code; + TVMArgsSetter setter(&pod_value, &type_code); + setter(0, kv.first); + TVMArgValue pod_arg(pod_value, type_code); + return PackedFuncValueConverter::From(pod_arg); + }(); + U new_value = [&]() { + TVMValue pod_value; + int type_code; + TVMArgsSetter setter(&pod_value, &type_code); + setter(0, kv.second); + TVMArgValue key_arg(pod_value, type_code); + return PackedFuncValueConverter::From(key_arg); + }(); + output.Set(new_key, new_value); + } + return output; + } + static Map From(const TVMRetValue& val) { + auto untyped_map = val.AsObjectRef>(); + + if (ObjectTypeChecker>::Check(untyped_map.get())) { + // Early bail-out for common case where no type conversions are + // required. + return Downcast>(untyped_map); + } + + Map output; + for (const auto& kv : untyped_map) { + T new_key = [&]() { + TVMRetValue pod; + pod = kv.first; + return PackedFuncValueConverter::From(pod); + }(); + U new_value = [&]() { + TVMRetValue pod; + pod = kv.second; + return PackedFuncValueConverter::From(pod); + }(); + output.Set(new_key, new_value); } + return output; } }; @@ -2181,7 +2584,7 @@ struct PackedFuncValueConverter> { return opt.value(); } - if (auto opt = TryValueConverter(val)) { + if (auto opt = TryValueConverter(val)) { return opt.value(); } @@ -2192,10 +2595,10 @@ struct PackedFuncValueConverter> { << " but got " << ArgTypeCode2Str(val.type_code()); } - template - static Optional TryAsObjectRef(const TVMPODValue_& val) { - if (val.IsObjectRef()) { - return VType(val.AsObjectRef()); + template + static Optional TryAsObjectRef(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return VType(val.template AsObjectRef()); } else if constexpr (sizeof...(VarRest)) { return TryAsObjectRef(val); } else { @@ -2203,15 +2606,15 @@ struct PackedFuncValueConverter> { } } - template + template static Optional TryValueConverter(const PODSubclass& val) { try { return VType(PackedFuncValueConverter::From(val)); - } catch (const InternalError&) { + } catch (const Error&) { } if constexpr (sizeof...(VarRest)) { - return TryValueConverter(val); + return TryValueConverter(val); } else { return NullOpt; } diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index d47ac94e067e..4c1d1fc1f3d2 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -113,7 +113,15 @@ class TargetNode : public Object { "Can only call GetAttr with ObjectRef types."); auto it = attrs.find(attr_key); if (it != attrs.end()) { - return Downcast>((*it).second); + // For backwards compatibility, return through TVMRetValue. + // This triggers any automatic conversions registered with + // PackedFuncValueConverter. Importantly, this allows use of + // `GetAttr` and `GetAttr` for properties that + // are stored internally as `runtime::Box` and + // `runtime::Box`. + TVMRetValue ret; + ret = (*it).second; + return ret; } else { return default_value; } diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 130aea32f844..6b3b9c31a645 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -445,8 +445,8 @@ constexpr const char* kRelayToTIR = "RelayToTIR"; .add_attr_option("model") \ .add_attr_option>("libs") \ .add_attr_option("host") \ - .add_attr_option("from_device") \ - .add_attr_option("target_device_type") + .add_attr_option("from_device") \ + .add_attr_option("target_device_type") } // namespace tvm diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index d9b65dc8745c..28cb022151d2 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1155,6 +1155,63 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } // namespace tir } // namespace tvm +namespace tvm { +namespace runtime { + +// Automatic conversion into PrimExpr, when called through the FFI. +// Automatic conversions into IntImm, Integer, and Bool are registered +// in "tvm/ir/expr.h", as they are currently in use outside of TIR. + +template <> +struct PackedFuncValueConverter { + template + static Optional TryFrom(const PODSubclass& val) { + auto type_code = val.type_code(); + bool can_convert = type_code == kTVMDataType || type_code == kTVMBytes || + type_code == kTVMStr || val.template IsObjectRef(); + if (can_convert) { + return tvm::tir::StringImm(PackedFuncValueConverter::From(val)); + } else { + return NullOpt; + } + } + + template + static tvm::tir::StringImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } + } +}; + +template <> +struct PackedFuncValueConverter { + // Common rule for RetValue and ArgValue. Templated to ensure + // correct delegation to `operator std::string()` for either + // TVMArgValue or TVMRetValue. + template + static PrimExpr From(const PODSubclass& val) { + if (auto opt = val.TryAsBool()) { + // Check against val.TryAsBool directly, to avoid the + // bounds-checking in PackedFuncValueConverter::TryFrom. + return tvm::Bool(opt.value()); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else { + return PrimExpr::FromObject_(val.template AsObjectRef()); + } + } +}; + +} // namespace runtime +} // namespace tvm + namespace std { template <> struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 274ebd0a6558..1d218c6a7c61 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -264,7 +264,7 @@ class TensorIntrin : public ObjectRef { * B[vi, vj] = A[vi, vj] * \endcode */ -PrimFunc Specialize(PrimFunc func, const Map& param_map); +PrimFunc Specialize(PrimFunc func, const Map>& param_map); /*! * \brief PrimFunc specific attribute names. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9b23973b6f8f..092bd52d5634 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -224,8 +224,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return The random variable sampled from candidates */ - virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) = 0; + virtual ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) = 0; /*! * \brief Sample the factors to perfect tile a specific loop * \param loop_rv The loop to be tiled diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 520e0e42ebbe..8f674eea2ec6 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -60,14 +60,36 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) + + # Handle return values that subclass from both TVM objects and + # python native objects (e.g. runtime.String, a subclass of str). if issubclass(cls, PyNativeObject): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) obj.handle = handle return cls.__from_tvm_object__(cls, obj) + # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) obj.handle = handle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. + # + # The `hasattr` check is done on the object's class, not the + # object itself, to avoid edge cases that can result in invalid + # error messages. If a C++ `LOG(FATAL) << nested_obj;` statement + # requires C++ to Python conversions in order to print + # `nested_obj`, then the `AttributeError` used internally by + # `hasattr` may overwrite the text being collected by + # `LOG(FATAL)`. By checking for the method on the class instead + # of the instance, we avoid throwing the `AttributeError`. + # if hasattr(type(obj), "__into_pynative_object__"): + # return obj.__into_pynative_object__() + return obj diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 5f3aa04914be..6dab1a5db1f4 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -134,6 +134,11 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode + elif isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + values[i].v_bool = arg + type_codes[i] = ArgTypeCode.BOOL elif isinstance(arg, Integral): values[i].v_int64 = arg type_codes[i] = ArgTypeCode.INT @@ -147,7 +152,7 @@ def _make_tvm_args(args, temp_args): values[i].v_int64 = _device_to_int64(arg) type_codes[i] = ArgTypeCode.DLDEVICE elif isinstance(arg, (bytearray, bytes)): - # from_buffer only taeks in bytearray. + # from_buffer only takes in bytearray. if isinstance(arg, bytes): byte_arr = bytearray(arg) temp_args.append(byte_arr) diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index 38d3cd72b55d..45f36eafd78a 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -27,6 +27,7 @@ class TVMValue(ctypes.Union): _fields_ = [ ("v_int64", ctypes.c_int64), + ("v_bool", ctypes.c_bool), ("v_float64", ctypes.c_double), ("v_handle", ctypes.c_void_p), ("v_str", ctypes.c_char_p), @@ -94,6 +95,7 @@ def _device_to_int64(dev): RETURN_SWITCH = { ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.BOOL: lambda x: x.v_bool, ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, @@ -104,6 +106,7 @@ def _device_to_int64(dev): C_TO_PY_ARG_SWITCH = { ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.BOOL: lambda x: x.v_bool, ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 69e1355f7d13..0f7e5fcae6bd 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -16,6 +16,7 @@ # under the License. from ..base import raise_last_ffi_error +from libcpp cimport bool as bool_t from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule @@ -38,7 +39,8 @@ cdef enum TVMArgTypeCode: kTVMBytes = 12 kTVMNDArrayHandle = 13 kTVMObjectRefArg = 14 - kTVMExtBegin = 15 + kTVMArgBool = 15 + kTVMExtBegin = 16 cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct DLDataType: @@ -66,6 +68,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct TVMValue: int64_t v_int64 + bool_t v_bool double v_float64 void* v_handle const char* v_str diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 94a9310d7815..ff38cd3d0ec2 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -60,7 +60,17 @@ cdef inline object make_ret_object(void* chandle): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = chandle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. + # if hasattr(obj, '__into_pynative_object__'): + # return obj.__into_pynative_object__) + return obj + # return obj.__into_pynative_object__() class PyNativeObject: diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 3d1e87bf563d..7977f37d0be5 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -45,7 +45,7 @@ cdef int tvm_callback(TVMValue* args, tcode == kTVMModuleHandle or tcode == kTVMNDArrayHandle or tcode == kTVMObjectRefArg or - tcode > kTVMExtBegin): + tcode >= kTVMExtBegin): CHECK_CALL(TVMCbArgToReturn(&value, &tcode)) if tcode != kTVMDLTensorHandle: @@ -118,6 +118,11 @@ cdef inline int make_arg(object arg, ptr = arg._tvm_handle value[0].v_handle = (ptr) tcode[0] = arg.__class__._tvm_tcode + elif isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + value[0].v_bool = arg + tcode[0] = kTVMArgBool elif isinstance(arg, Integral): value[0].v_int64 = arg tcode[0] = kInt @@ -209,6 +214,8 @@ cdef inline object make_ret(TVMValue value, int tcode): return make_ret_object(value.v_handle) elif tcode == kTVMNullptr: return None + elif tcode == kTVMArgBool: + return value.v_bool elif tcode == kInt: return value.v_int64 elif tcode == kFloat: diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index f148e26f3fcb..03dc18ea6e0b 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -48,7 +48,8 @@ class ArgTypeCode(object): BYTES = 12 NDARRAY_HANDLE = 13 OBJECT_RVALUE_REF_ARG = 14 - EXT_BEGIN = 15 + BOOL = 15 + EXT_BEGIN = 16 class TVMByteArray(ctypes.Structure): diff --git a/python/tvm/driver/tvmc/registry.py b/python/tvm/driver/tvmc/registry.py index c2e74eb1935e..b76202a730a2 100644 --- a/python/tvm/driver/tvmc/registry.py +++ b/python/tvm/driver/tvmc/registry.py @@ -20,11 +20,23 @@ from tvm.driver.tvmc import TVMCException -# We can't tell the type inside an Array but all current options are strings so -# it can default to that. Bool is used alongside Integer but aren't distinguished -# between as both are represented by IntImm -INTERNAL_TO_NATIVE_TYPE = {"runtime.String": str, "IntImm": int, "Array": str} -INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} +# We can't tell the type inside an Array but all current options are +# strings so it can default to that. runtime.BoxBool is used to +# distinguish from runtime.BoxInt. +INTERNAL_TO_NATIVE_TYPE = { + "runtime.String": str, + "runtime.BoxBool": bool, + "runtime.BoxFloat": float, + "runtime.BoxInt": int, + "Array": str, +} +INTERNAL_TO_HELP = { + "runtime.String": " string", + "runtime.BoxBool": " bool", + "runtime.BoxInt": " int", + "runtime.BoxFloat": " float", + "Array": " options", +} def _generate_registry_option_args(parser, registry, name): diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index 6f0a6dd7d155..6afb383c9f04 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -61,7 +61,7 @@ def get_int_tuple(self, key): ------- value: Tuple of int """ - return tuple(x.value for x in self.__getattr__(key)) + return tuple(x if isinstance(x, int) else x.value for x in self.__getattr__(key)) def get_int(self, key): """Get a python int value of a key diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index c70ac2acc71b..263976fa98ff 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -20,7 +20,7 @@ import tvm._ffi -from ..runtime import Object, Scriptable, const, convert +from ..runtime import Object, Scriptable from . import _ffi_api from .base import Node, Span from .type import Type @@ -184,9 +184,6 @@ class Range(Node, Scriptable): def __init__( self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None ) -> None: - if end is None: - end = convert(begin) - begin = const(0, dtype=end.dtype, span=span) self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span) @staticmethod diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 6f76452a57b5..51d9a013d8b3 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -28,6 +28,7 @@ from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule +from tvm.script import tir as T from . import _ffi_api from .logging import Logger, get_logger, get_logging_func @@ -47,7 +48,7 @@ def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: if isinstance(mod, PrimFunc): if not (mod.attrs and "global_symbol" in mod.attrs): mod = mod.with_attr("global_symbol", "main") - mod = mod.with_attr("tir.noalias", True) + mod = mod.with_attr("tir.noalias", T.bool(True)) mod = IRModule({"main": mod}) if not isinstance(mod, IRModule): raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") diff --git a/python/tvm/relax/op/statistical.py b/python/tvm/relax/op/statistical.py index eb44696871eb..502d058ffdf6 100644 --- a/python/tvm/relax/op/statistical.py +++ b/python/tvm/relax/op/statistical.py @@ -195,7 +195,7 @@ def cumprod( data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None, - exclusive: Optional[bool] = None, + exclusive: bool = False, ): """Numpy style cumprod op. Return the cumulative product of the elements along a given axis. @@ -213,9 +213,9 @@ def cumprod( Type of the returned array and of the accumulator in which the elements are computed. If dtype is not specified, it defaults to the dtype of data. - exclusive : Optional[bool] - If true will return exclusive sum in which the first element is not - included. + exclusive : bool + If false (default), all elements are included in the product. If + true, the first element is excluded from the product. Returns ------- @@ -247,6 +247,9 @@ def cumprod( cumprod(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 1, 0, 0, 0, 0] """ + if exclusive is None: + exclusive = False + return _ffi_api.cumprod(data, axis, dtype, exclusive) # type: ignore @@ -254,7 +257,7 @@ def cumsum( data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None, - exclusive: Optional[bool] = None, + exclusive: bool = False, ): """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along a given axis. @@ -272,9 +275,9 @@ def cumsum( Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. - exclusive : Optional[bool] - If true will return exclusive sum in which the first element is not - included. + exclusive : bool + If false (default), all elements are included in the sum. If + true, the first element is excluded from the sum. Returns ------- @@ -306,6 +309,9 @@ def cumsum( cumsum(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 2, 2, 3, 4, 4] """ + if exclusive is None: + exclusive = False + return _ffi_api.cumsum(data, axis, dtype, exclusive) # type: ignore diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 1ed16363b20a..4c670bbe74b2 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -171,11 +171,19 @@ def visit_call_(self, op: relax.Call) -> str: def display_attrs(attr_key): attr_val = op.attrs[attr_key] - # attrs can be strings but also other types; - # we want to wrap strings in quotes - # (__repr__ would work but it uses single quotes) - attr_str = wrap_quotes(attr_val) if isinstance(attr_val, str) else str(attr_val) - return f"{wrap_quotes(attr_key)}: {attr_str}" + + if isinstance(attr_val, str): + # attrs can be strings but also other types; + # we want to wrap strings in quotes + # (__repr__ would work but it uses single quotes) + attr_val = wrap_quotes(attr_val) + elif isinstance(attr_val, tvm.tir.IntImm): + if attr_val.dtype == "bool": + attr_val = bool(attr_val.value) + else: + attr_val = int(attr_val.value) + + return f"{wrap_quotes(attr_key)}: {attr_val}" fields["attrs"] = self.build_list( map(display_attrs, op.attrs.keys()), diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index 71bf8509a63e..aba7ae912c54 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -139,14 +139,14 @@ def _check_well_formed(self, mod: IRModule): # Check function attrs if not self.PARAM_NUM_ATTR_KEY in mod.attrs or not isinstance( - mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm + mod.attrs[self.PARAM_NUM_ATTR_KEY], (IntImm, int) ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " f"{self.PARAM_NUM_ATTR_KEY}" ) if not self.STATE_NUM_ATTR_KEY in mod.attrs or not isinstance( - mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm + mod.attrs[self.STATE_NUM_ATTR_KEY], (IntImm, int) ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 9323bc40da69..e1cab4cbd53b 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -97,6 +97,9 @@ def convert_to_expr(value: Any) -> Expr: if isinstance(value, int): return PrimValue(tir.IntImm("int64", value)) + if isinstance(value, float): + return PrimValue(tir.FloatImm("float64", value)) + tvm_value = convert_to_object(value) # Case 1 if isinstance(tvm_value, Expr): # type: ignore diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 97d7cfa93c8d..199193f75939 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -76,7 +76,7 @@ def get_section_begin_coords(split: tvm.relay.Expr) -> List[int]: # 0 is the beginning of the first section. return [0] + list(indices_or_sections) split_axis_len = input_shape[split_axis].value - section_length = split_axis_len // indices_or_sections.value + section_length = split_axis_len // indices_or_sections return list(range(0, split_axis_len, section_length)) def callback( diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 6b9b311c83b5..dca7b995b22d 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Gradient definitions for Relay operators""" +import tvm from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.error import OpError @@ -383,6 +384,8 @@ def concatenate_grad(orig, grad): axis_dims = [ty.shape[orig.attrs.axis] for ty in t.checked_type.fields] splits, cumsum = [], 0 for dim in axis_dims[:-1]: + if isinstance(dim, tvm.tir.IntImm): + dim = dim.value cumsum += dim splits.append(cumsum) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 93df67ff6b99..8bca72655491 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1057,10 +1057,10 @@ def split_shape_func(attrs, inputs, _): return [ _split_shape_func( inputs[0], - convert(i), - convert(indices_or_sections), - convert(param_is_indices), - convert(axis), + i, + indices_or_sections, + param_is_indices, + axis, ) for i in range(num_out) ] diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index dd04d613079b..c4eff3fcc9e0 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1630,10 +1630,10 @@ def __init__(self, func_body): def convert_indices_or_sections(self, indices_or_sections): # split_v if isinstance(indices_or_sections, tvm.ir.container.Array): - values = [i.value for i in indices_or_sections] + values = [int(i) for i in indices_or_sections] # split else: - values = indices_or_sections.value + values = int(indices_or_sections) return values def is_valid(self): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index ef1cdb3afdd8..dd9c670e2a37 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -18,6 +18,8 @@ # pylint: disable=import-outside-toplevel """Transform operators.""" +from typing import Optional + from ...tir import expr as _expr from ..expr import Constant, Expr, Tuple, TupleWrapper, const from . import _make @@ -855,13 +857,14 @@ def broadcast_to(data, shape): The resulting tensor. """ if isinstance(shape, Constant): - shape = list(shape.data.numpy()) - if isinstance(shape, Expr): + shape = shape.data.numpy() + shape = [_expr.IntImm(str(shape.dtype), int(value)) for value in shape] + elif isinstance(shape, Expr): return _dyn_make.broadcast_to(data, shape) + if isinstance(shape, int): shape = [shape] - if isinstance(shape, (list, tuple)): - shape = list(shape) + return _make.broadcast_to(data, shape) @@ -1938,9 +1941,8 @@ def stft( return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided) -def dft(re_data, im_data, inverse=False): - """ - Computes the discrete Fourier transform of input (calculation along the last axis). +def dft(re_data, im_data, inverse: Optional[bool] = False): + """Computes the discrete Fourier transform of input (calculation along the last axis). This gives frequency components of the signal as they change over time. Parameters @@ -1952,8 +1954,11 @@ def dft(re_data, im_data, inverse=False): N-D tensor, imaginary part of the input signal. If the signal is real, then the values of this tensor are zeros. - inverse : bool + inverse : Optional[bool] + Whether to perform the inverse discrete fourier transform. + Providing None is equivalent to False, and is maintained for + compatibility. Returns ------- @@ -1961,7 +1966,11 @@ def dft(re_data, im_data, inverse=False): The Fourier Transform of the input (Real part). im_output : relay.Expr The Fourier Transform of the input (Imaginary part). + """ + if inverse is None: + inverse = False + return TupleWrapper(_make.dft(re_data, im_data, inverse), 2) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 7ad838895c9f..6eef6ff3ffae 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -364,9 +364,8 @@ def split(expr, type_map): arg = expr.args[0] t = type_map[arg] attrs = {**expr.attrs} - if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm): - num_split = attrs["indices_or_sections"].value - attrs["indices_or_sections"] = num_split + if isinstance(attrs["indices_or_sections"], int): + num_split = attrs["indices_or_sections"] else: num_split = len(attrs["indices_or_sections"]) + 1 return [expr, TupleAffineType([t] * num_split)] diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index f182cd9bfd2f..301f0ef66286 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -27,11 +27,11 @@ from .profiling import Report # function exposures -from .object_generic import convert_to_object, convert, const from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library -from .container import String, ShapeTuple +from .container import String, ShapeTuple # , BoxBool +from .object_generic import convert_to_object, convert, const from .params import ( save_param_dict, load_param_dict, diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 686b4a26c80c..f1a0706a387d 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -172,3 +172,41 @@ def __eq__(self, other): return False return True + + +# @tvm._ffi.register_object("runtime.BoxBool") +# class BoxBool(Object): +# """A boolean wrapped as a tvm Object + +# Parameters +# ---------- +# value: bool + +# The value to hold +# """ + +# def __init__(self, value: bool): +# # Convert to int to avoid an infinite recursion, because +# # BoxBool may be constructed in _make_tvm_args, and calling +# # the packed func `_ffi_api.BoxBool` internally calls +# # `_make_tvm_args`. +# self.__init_handle_by_constructor__(_ffi_api.BoxBool, int(value)) + +# def __into_pynative_object__(self) -> bool: +# return self.value + +# @property +# def value(self) -> bool: +# """Unwrap the boxed value. + +# This is implemented explicitly rather than using the usual +# PackedFunc handling or AttrVisitor mechanics for two reasons. +# First, because the PackedFunc handling would require ambiguous +# representations between `True`/`1` and `False`/`0`. Second, +# because the boxing/unboxing must be available in +# `libtvm_runtime.so`, and AttrVisitor is only available in +# `libtvm.so`. +# """ +# unboxed_bool = _ffi_api.UnBoxBool(self) +# assert unboxed_bool is not None +# return bool(unboxed_bool) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 887c2faaeb2b..20909c53c787 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -38,65 +38,62 @@ def asobject(self): ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PackedFuncBase, PyNativeObject) -def convert_to_object(value, span=None): +def convert_to_object(value): """Convert a Python value to corresponding object type. + Type conversions performed by this function must *only* produce + types that are supported by `libtvm_runtime.so`. This function + must be usable in environments where only TVM runtime support is + present. Automatic conversions to compile-time representations + (e.g. `tir.IntImm` or `relax.PrimValue`) should not be done as + part of this conversion, as these types are not available in + `libtvm_runtime.so`. + Parameters ---------- value : str The value to be inspected. - span : Optional[Span] - The location of this itervar in the source code. - Returns ------- obj : Object The corresponding object value. + """ + if isinstance(value, ObjectTypes): return value - if isinstance(value, bool): - return const(value, "uint1x1", span=span) - if isinstance(value, Number): - return const(value, span=span) - if isinstance(value, string_types): + elif isinstance(value, (bool, int, float)): + return value + elif isinstance(value, string_types): return _ffi_api.String(value) - if isinstance(value, (list, tuple)): - value = [convert_to_object(x) for x in value] + elif isinstance(value, (list, tuple)): + # The call to _ffi_api.Array will convert its own arguments, + # so we don't need to apply any explicit conversions here. return _ffi_api.Array(*value) - if isinstance(value, dict): - vlist = [] - for item in value.items(): - if ( - not isinstance(item[0], ObjectTypes) - and not isinstance(item[0], string_types) - and not isinstance(item[0], Number) - ): - raise ValueError("key of map must already been a container type") - vlist.append(convert_to_object(item[0])) - vlist.append(convert_to_object(item[1])) + elif isinstance(value, dict): + if any(not isinstance(key, (ObjectTypes, string_types, Number)) for key in value): + raise ValueError("key of map must already been a container type") + + vlist = [kv for item in value.items() for kv in item] return _ffi_api.Map(*vlist) - if isinstance(value, ObjectGeneric): + elif isinstance(value, ObjectGeneric): return value.asobject() - if callable(value): + elif callable(value): return convert_to_tvm_func(value) - if value is None: + elif value is None: return None - - raise ValueError(f"don't know how to convert type {type(value)} to object") + else: + raise TypeError(f"don't know how to convert type {type(value)} to object") -def convert(value, span=None): +def convert(value): """Convert value to TVM object or function. Parameters ---------- value : python value - span : Optional[Span] - The location of this statement in the source code. - Returns ------- tvm_val : Object or Function @@ -107,29 +104,29 @@ def convert(value, span=None): This function is redirected to `convert_to_object` as it is widely used in the codebase. We can choose one to keep and discard the other one later. """ - return convert_to_object(value, span=span) + + return convert_to_object(value) def _scalar_type_inference(value): if hasattr(value, "dtype"): - dtype = str(value.dtype) + return str(value.dtype) elif isinstance(value, bool): - dtype = "bool" + return "bool" elif isinstance(value, float): # We intentionally prefer convert the float to float32 since it's more common in DL. if -3.40282347e38 <= value <= 3.40282347e38: - dtype = "float32" + return "float32" else: - dtype = "float64" + return "float64" elif isinstance(value, int): # We intentionally prefer convert the python int to int32 since it's more common in DL. if -2147483648 <= value <= 2147483647: - dtype = "int32" + return "int32" else: - dtype = "int64" + return "int64" else: raise NotImplementedError(f"Cannot automatically inference the type. value={value}") - return dtype def const(value, dtype=None, span=None): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index e545bc3a5e53..3107354ac353 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -536,6 +536,8 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ value = self.eval_expr(node.value) + if value is None: + self.report_error(node, "Expression to be returned must be a PrimExpr") T.evaluate(tvm.tir.ret(value)) diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 462066106a9d..948a0d7665ff 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -96,7 +96,7 @@ def _allocate_tensor(func_id, args): ) shape = args[0] for i in shape: - _internal_assert(isinstance(i, _expr.PrimExpr), "The shape should be an expression") + _internal_assert(isinstance(i, (_expr.PrimExpr, int)), "The shape should be an expression") if n > 1: _internal_assert(isinstance(args[1], str), "The data type should be an str") _internal_assert( @@ -131,9 +131,11 @@ def len(func_id, args): def _cast(func_id, args): _internal_assert( - args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), - "Only one expression can be cast", + args.__len__() == 1, + f"Casting to {func_id} only supports a single argument", ) + # The FFI can handle any conversion of `args[0]` into PrimExpr, if + # required. return _expr.Cast(func_id, args[0]) @@ -145,9 +147,7 @@ def _cast(func_id, args): def ceil_div(func_id, args): _internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!") _internal_assert(args.__len__() == 2, "2 arguments expected for division!") - _internal_assert(isinstance(args[0], _expr.PrimExpr), "Only expressions can div") - _internal_assert(isinstance(args[1], _expr.PrimExpr), "Only expressions can div") - a, b = args[0], args[1] + a, b = args return (a + b - 1) // b diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 846ef818ea54..bd5a060cd01c 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -279,7 +279,7 @@ def visit_Num(self, node): return tvm.runtime.const(node.n, dtype) def visit_NameConstant(self, node): - return tvm.runtime.convert(node.value) + return tvm.tir.const(node.value) def visit_AugAssign(self, node): buf = self.visit(node.target) @@ -376,7 +376,7 @@ def visit_Subscript(self, node): args = [args] arr = self.visit(node.value) - if isinstance(arr, Array): + if isinstance(arr, (Array, list, tuple)): for i in args: if isinstance(i, numbers.Integral): arr = arr[i] diff --git a/python/tvm/te/hybrid/utils.py b/python/tvm/te/hybrid/utils.py index f653b3e83d8b..a515938fa524 100644 --- a/python/tvm/te/hybrid/utils.py +++ b/python/tvm/te/hybrid/utils.py @@ -33,9 +33,9 @@ # pylint: disable=invalid-name -np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) -tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) -halide_imm_types = (_expr.IntImm, _expr.FloatImm) +np_arg_types = (numpy.ndarray, *numeric_types) +tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr, *numeric_types, list, tuple, str) +halide_imm_types = (_expr.IntImm, _expr.FloatImm, *numeric_types) def _internal_assert(cond, err): @@ -91,19 +91,13 @@ def replace(op): def _is_tvm_arg_types(args): """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. If neither is true, raise a value error.""" - if isinstance(args[0], tvm_arg_types): - for elem in args[1:]: - _internal_assert( - isinstance(elem, tvm_arg_types), - f"Expecting a Var, Tensor or ConstExpr instance but {type(elem)} get!", - ) + if all(isinstance(elem, tvm_arg_types) for elem in args): return True - - _internal_assert( - isinstance(args[0], np_arg_types), f"Expect a numpy type but {type(args[0])} get!" - ) - for elem in args[1:]: - _internal_assert( - isinstance(elem, np_arg_types), f"Expect a numpy type but {type(elem)} get!" + elif all(isinstance(elem, np_arg_types) for elem in args): + return False + else: + raise ValueError( + f"Expected arguments to be entirely TVM types, " + f"or entirely numpy types, " + f"but received {[type(elem) for elem in args]}" ) - return False diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index dc2c67849925..64a282dcf755 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -53,7 +53,6 @@ def placeholder(shape, dtype=None, name="placeholder"): tensor: Tensor The created tensor """ - shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape dtype = "float32" if dtype is None else dtype return _ffi_api.Placeholder(shape, dtype, name) diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index d435e821acf3..930667242e29 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -64,16 +64,7 @@ def __call__(self, *indices): f"Need to provide {ndim} index in tensor but {len(indices)} was provided" ) indices = convert_to_object(indices) - args = [] - for x in indices: - if isinstance(x, _expr.PrimExpr): - args.append(x) - elif isinstance(x, _expr.IterVar): - args.append(x.var) - else: - raise ValueError("The indices must be expression") - - return _expr.ProducerLoad(self, args) + return _expr.ProducerLoad(self, indices) def __getitem__(self, indices): return TensorSlice(self, indices) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index bcfbe6575d52..0c8048d24d8b 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -21,6 +21,7 @@ from .buffer import Buffer, decl_buffer, DataProducer from .data_layout import Layout, BijectiveLayout, bijective_layout, layout +from .expr import convert from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index c78bb9e7ecd0..37976394f831 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -41,6 +41,10 @@ from .buffer import Buffer, DataProducer +def convert(expr) -> PrimExpr: + return _ffi_api.convert(expr) + + def div_ambiguity_error() -> RuntimeError: return RuntimeError( "TVM supports multiple types of integer divisions, " diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 50de995a9145..777d46ec7b0d 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,7 +17,7 @@ """Developer API of IR node builder make function.""" import tvm from tvm._ffi.base import string_types -from tvm.runtime import ObjectGeneric, convert, const +from tvm.runtime import ObjectGeneric, const from tvm.ir import container as _container from . import stmt as _stmt @@ -107,7 +107,9 @@ def __getitem__(self, index): def __setitem__(self, index, value): index = self._normalize_index(index) - value = convert(value) + if isinstance(value, (int, bool, float)): + value = tvm.tir.const(value) + value_element = value.dtype.split("x", maxsplit=1)[0] content_element = self._content_type.split("x", maxsplit=1)[0] if value_element != content_element: diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0bc299e403c5..8d9647b60049 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -19,13 +19,14 @@ from typing import Any, Optional, Union import tvm._ffi +from tvm import tir from tvm.ir import Array, Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import const, convert +from tvm.runtime import const from . import _ffi_api from .buffer import Buffer -from .expr import Call, CommReducer, IntImm, PrimExprWithOp, StringImm, Var +from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var def _pack_buffer(buf, span=None): @@ -181,7 +182,7 @@ def call_intrin(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, func_name, convert(args), span) + return Call(dtype, func_name, args, span) def call_pure_extern(dtype, func_name, *args, span=None): @@ -206,9 +207,7 @@ def call_pure_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call( - dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span - ) + return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span) def call_extern(dtype, func_name, *args, span=None): @@ -233,9 +232,7 @@ def call_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span - ) + return Call(dtype, Op.get("tir.call_extern"), [func_name, *args], span=span) def call_llvm_intrin(dtype, name, *args, span=None): @@ -1832,13 +1829,10 @@ def dp4a(vec1, vec2, acc=0): call : PrimExpr The call expression. """ - vec1 = convert(vec1) - vec2 = convert(vec2) - acc = convert(acc) return call_intrin("int32", "tir.dp4a", vec1, vec2, acc) -def ret(val): +def ret(val, span=None): """Create a tir return expression Parameters @@ -1846,14 +1840,16 @@ def ret(val): val : Expr The returned tir expression, whose data type is int, float or void pointer. + span : Optional[Span] + The location of this operator in the source code. + Returns ------- ret : PrimExpr The return expression """ - val = convert(val) - return call_intrin(val.dtype, "tir.ret", val) + return _ffi_api.ret(val, span) def any(*args, span=None): @@ -2038,7 +2034,7 @@ def exp(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp", x) @@ -2055,7 +2051,7 @@ def exp2(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp2", x) @@ -2072,7 +2068,7 @@ def exp10(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp10", x) @@ -2089,7 +2085,7 @@ def erf(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.erf", x) @@ -2106,7 +2102,7 @@ def tanh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.tanh", x) @@ -2123,7 +2119,7 @@ def sigmoid(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sigmoid", x) @@ -2140,7 +2136,7 @@ def log(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log", x) @@ -2157,7 +2153,7 @@ def log2(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log2", x) @@ -2174,7 +2170,7 @@ def log10(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log10", x) @@ -2191,7 +2187,7 @@ def log1p(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log1p", x) @@ -2208,7 +2204,7 @@ def tan(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.tan", x) @@ -2225,7 +2221,7 @@ def cos(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.cos", x) @@ -2242,7 +2238,7 @@ def cosh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.cosh", x) @@ -2259,7 +2255,7 @@ def acos(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.acos", x) @@ -2276,7 +2272,7 @@ def acosh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.acosh", x) @@ -2293,7 +2289,7 @@ def sin(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sin", x) @@ -2310,7 +2306,7 @@ def sinh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sinh", x) @@ -2327,7 +2323,7 @@ def asin(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.asin", x) @@ -2344,7 +2340,7 @@ def asinh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.asinh", x) @@ -2361,7 +2357,7 @@ def atan(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.atan", x) @@ -2378,7 +2374,7 @@ def atanh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.atanh", x) @@ -2398,8 +2394,8 @@ def atan2(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.atan2", x1, x2) @@ -2416,7 +2412,7 @@ def sqrt(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sqrt", x) @@ -2433,7 +2429,7 @@ def rsqrt(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.rsqrt", x) @@ -2679,8 +2675,8 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.nextafter", x1, x2) # type: ignore @@ -2700,8 +2696,8 @@ def hypot(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.hypot", x1, x2) # type: ignore @@ -2721,8 +2717,8 @@ def copysign(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.copysign", x1, x2) # type: ignore @@ -2742,8 +2738,8 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore @@ -2862,7 +2858,7 @@ def power(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore + return _ffi_api._OpPow(x, y, span) # type: ignore def pow(x, y, span=None): @@ -2884,7 +2880,7 @@ def pow(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore + return _ffi_api._OpPow(x, y, span) # type: ignore def popcount(x): @@ -2900,7 +2896,7 @@ def popcount(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.popcount", x) @@ -3032,8 +3028,8 @@ def fmod(x, y): z : PrimExpr The result. """ - x = convert(x) - y = convert(y) + x = tir.convert(x) + y = tir.convert(y) return call_intrin(x.dtype, "tir.fmod", x, y) @@ -3067,7 +3063,7 @@ def if_then_else(cond, t, f, span=None): Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions. """ - return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f), span) # type: ignore + return _ffi_api._OpIfThenElse(cond, t, f, span) # type: ignore def div(a, b, span=None): @@ -3314,34 +3310,23 @@ def _reduce_directly(*args): def _make_reduce(expr, axis, where=None, init=None): code = fcombine.__code__ assert fcombine.__code__.co_argcount == 2 - expr = convert(expr) + expr = tir.convert(expr) if init is not None: - init = convert(init) + init = tir.convert(init) if isinstance(expr, Array): size = len(expr) - larr = [] - rarr = [] + lhs = [] + rhs = [] dtypes = [] for i in range(size): dtype = expr[i].dtype dtypes.append(dtype) lname = code.co_varnames[0] + "_" + str(i) - larr.append(Var(lname, dtype)) + lhs.append(Var(lname, dtype)) rname = code.co_varnames[1] + "_" + str(i) - rarr.append(Var(rname, dtype)) - if init is not None: - init = convert(init) - assert isinstance(init, Array) - assert len(init) == size - for init_i in range(size): - init_i = convert(init_i) - assert isinstance( - init_i, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm) - ) - else: - init = convert([]) - lhs = convert(larr) - rhs = convert(rarr) + rhs.append(Var(rname, dtype)) + if init is None: + init = [] result = fcombine(lhs, rhs) id_elem = fidentity(*dtypes) else: @@ -3352,22 +3337,18 @@ def _make_reduce(expr, axis, where=None, init=None): rvar = Var(code.co_varnames[1], dtype) result = [fcombine(lvar, rvar)] id_elem = [fidentity(dtype)] - lhs = convert([lvar]) - rhs = convert([rvar]) - expr = convert([expr]) + lhs = [lvar] + rhs = [rvar] + expr = [expr] if init is not None: - assert isinstance(init, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm)) - init = convert([init]) - result = convert(result) - id_elem = convert(id_elem) + init = [init] combiner = CommReducer(lhs, rhs, result, id_elem) - axis = convert(axis if isinstance(axis, (list, tuple)) else [axis]) + if not isinstance(axis, (list, tuple, tvm.ir.Array)): + axis = [axis] if where is None: - where = convert(True) + where = tir.convert(True) if init is None: - outputs = tuple( - tvm.tir.Reduce(combiner, expr, axis, where, i, convert([])) for i in range(size) - ) + outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, []) for i in range(size)) else: outputs = tuple( tvm.tir.Reduce(combiner, expr, axis, where, i, init) for i in range(size) diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index cb8d5ce9973e..85377560f1fc 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -39,17 +39,20 @@ def _json_from_tvm(obj): if obj is None: return None - if isinstance(obj, Array): + elif isinstance(obj, (bool, int, float, str)): + return obj + elif isinstance(obj, Array): return [_json_from_tvm(i) for i in obj] - if isinstance(obj, Map): + elif isinstance(obj, Map): return {_json_from_tvm(k): _json_from_tvm(v) for k, v in obj.items()} - if isinstance(obj, String): + elif isinstance(obj, String): return str(obj) - if isinstance(obj, (IntImm, FloatImm)): + elif isinstance(obj, (IntImm, FloatImm)): return obj.value - if isinstance(obj, IndexMap): + elif isinstance(obj, IndexMap): return save_json(obj) - raise TypeError("Not supported type: " + str(type(obj))) + else: + raise TypeError("Not supported type: " + str(type(obj))) @_register_object("tir.Trace") diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index bf6a9c75516f..cc1a28b9dee0 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -468,7 +468,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): C = out.op.input_tensors[0] A = C.op.input_tensors[0] in_type = A.dtype - use_scalable_vectors = out.op.attrs["use_scalable_vectors"].value + use_scalable_vectors = bool(out.op.attrs["use_scalable_vectors"]) tile_M, tile_K = arm_utils.get_tiling_A(False, in_type) tile_N, _ = arm_utils.get_tiling_B_transformed(False, in_type, use_scalable_vectors) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 83b000a4b9bb..0a7acfa50444 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -295,15 +295,11 @@ def batch_matmul_int8( # pad for _dp4a vectorize pad_x = te.compute( (XB, M, nK), - lambda b, i, j: tvm.te.if_then_else( - j >= XK, tvm.runtime.convert(0).astype(x.dtype), x[b, i, j] - ), + lambda b, i, j: tvm.te.if_then_else(j >= XK, tvm.tir.const(0, x.dtype), x[b, i, j]), ) pad_y = te.compute( (YB, N, nK), - lambda b, i, j: tvm.te.if_then_else( - j >= YK, tvm.runtime.convert(0).astype(y.dtype), y[b, i, j] - ), + lambda b, i, j: tvm.te.if_then_else(j >= YK, tvm.tir.const(0, y.dtype), y[b, i, j]), ) out = te.compute( diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index 8d59c2a035a9..b98d9c102baa 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -48,7 +48,7 @@ pub struct ModuleNode { crate::external! { #[name("runtime.RuntimeEnabled")] - fn runtime_enabled(target: CString) -> i32; + fn runtime_enabled(target: CString) -> bool; #[name("runtime.ModuleLoadFromFile")] fn load_from_file(file_name: CString, format: CString) -> Module; @@ -121,8 +121,7 @@ impl Module { /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let target = CString::new(target).unwrap(); - let enabled = runtime_enabled(target).unwrap(); - enabled != 0 + runtime_enabled(target).unwrap() } /// Returns the underlying module handle. diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index a74cbe318e2d..2c1f7db6adb0 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -73,6 +73,7 @@ macro_rules! TVMPODValue { Int(i64), UInt(i64), Float(f64), + Bool(bool), Null, DataType(DLDataType), String(*mut c_char), @@ -95,6 +96,7 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), + TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool), TVMArgTypeCode_kTVMNullptr => Null, TVMArgTypeCode_kTVMDataType => DataType($value.v_type), TVMArgTypeCode_kDLDevice => Device($value.v_device), @@ -117,6 +119,7 @@ macro_rules! TVMPODValue { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), + Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool), Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), @@ -263,6 +266,7 @@ macro_rules! impl_pod_value { impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); impl_pod_value!(Float, f64, [f32, f64]); +impl_pod_value!(Bool, bool, [bool]); impl_pod_value!(DataType, DLDataType, [DLDataType]); impl_pod_value!(Device, DLDevice, [DLDevice]); @@ -380,37 +384,6 @@ impl TryFrom for std::ffi::CString { } } -// Implementations for bool. - -impl<'a> From<&bool> for ArgValue<'a> { - fn from(s: &bool) -> Self { - (*s as i64).into() - } -} - -impl From for RetValue { - fn from(s: bool) -> Self { - (s as i64).into() - } -} - -impl TryFrom for bool { - type Error = ValueDowncastError; - - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> bool, - |RetValue::Int(val)| { !(val == 0) }) - } -} - -impl<'a> TryFrom> for bool { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) }) - } -} - impl From<()> for RetValue { fn from(_: ()) -> Self { RetValue::Null diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index e03d4302c89f..82e439cddbc2 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -554,9 +554,19 @@ class FlopEstimator : public ExprFunctor { if (auto pop = op.as()) { if (pop->attrs.count("FLOP")) { // Use user-provided FLOP - auto pint = pop->attrs["FLOP"].as(); - ICHECK(pint != nullptr); - ret += pint->value; + ObjectRef annotation = pop->attrs["FLOP"]; + auto value = [&]() -> int64_t { + if (auto runtime_int = annotation.as()) { + return runtime_int->value; + } else if (auto int_imm = annotation.as()) { + return int_imm->value; + } else { + LOG(FATAL) << "FLOP annotation must be an integer, " + << "but was an object of type " << annotation->GetTypeKey(); + } + }(); + + ret += value; } else { // Estimate by parsing the compute body double num_element = AxisLengthProd(pop->axis); diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 862e593c9dd3..0bf6da255d2a 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -482,7 +482,8 @@ std::vector> RuleCustomSketch::Apply(const SketchPolicyNod std::vector> ret; for (const auto& item : apply_ret) { CHECK_EQ(item.size(), 2); - auto next = item[1].as(); + auto next = item[1].as(); + ICHECK(next); ret.emplace_back(Downcast(item[0]), next->value); } return ret; diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index 76fb77dd9527..cc6b0ab23756 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -101,7 +101,7 @@ inline int OperationToStage(const te::Operation& op, const State& state) { /*! \brief Get an integer from a tvm str Map. */ inline int GetIntParam(const Map& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pint = attr_dict[key].as(); + auto pint = attr_dict[key].as(); ICHECK(pint != nullptr); return pint->value; } @@ -109,7 +109,7 @@ inline int GetIntParam(const Map& attr_dict, const std::strin /*! \brief Get a double from a tvm str Map. */ inline double GetDoubleParam(const Map& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pdouble = attr_dict[key].as(); + auto pdouble = attr_dict[key].as(); ICHECK(pdouble != nullptr); return pdouble->value; } @@ -120,10 +120,12 @@ inline std::string GetStringParam(const Map& attr_dict, const const auto& target = attr_dict[key]; if (auto pstr = target.as()) { return pstr->value; + } else if (auto pstr = target.as()) { + return pstr->data; + } else { + LOG(FATAL) << "Could not convert object " << target << " of type " << target->GetTypeKey() + << " to string"; } - auto pstr = target.as(); - ICHECK(pstr != nullptr); - return pstr->data; } /*! \brief Get a iterator name set from a tvm str Map. */ diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 289c1b79fd66..708fb56c9851 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -100,8 +100,17 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { const ObjectRef& value = doc->value; if (!value.defined()) { output_ << "\"\""; + } else if (const auto* runtime_int = value.as()) { + output_ << runtime_int->value; } else if (const auto* int_imm = value.as()) { output_ << int_imm->value; + } else if (const auto* runtime_float = value.as()) { + output_.precision(config_.float_precision); + if (std::isinf(runtime_float->value) || std::isnan(runtime_float->value)) { + output_ << '"' << runtime_float->value << '"'; + } else { + output_ << runtime_float->value; + } } else if (const auto* float_imm = value.as()) { output_.precision(config_.float_precision); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index 7e96c657a711..99be910bd70a 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -33,6 +33,10 @@ namespace msc { LiteralDoc PrototxtPrinter::ToLiteralDoc(const ObjectRef& obj) { if (obj.as()) { return LiteralDoc::Str(Downcast(obj), NullOpt); + } else if (auto ptr = obj.as()) { + return LiteralDoc::Int(ptr->value, NullOpt); + } else if (auto ptr = obj.as()) { + return LiteralDoc::Float(ptr->value, NullOpt); } else if (obj.as()) { return LiteralDoc::Int(Downcast(obj)->value, NullOpt); } else if (obj.as()) { diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index f58f95ae53b0..5fcbe924ae1c 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -263,6 +263,10 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { obj_string = ""; } else if (obj.as()) { obj_string = Downcast(obj); + } else if (const auto* n = obj.as()) { + obj_string = std::to_string(n->value); + } else if (const auto* n = obj.as()) { + obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 105ac063e0ea..1e576bc91002 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -171,9 +171,10 @@ Array CreatePassList(bool disable_loop_partition) { // phase passes is of the form // [[phase_number, pass], [phase_number, pass]... ] for (Array phase_pass : add_lower_pass) { - const IntImmNode* phase_num = phase_pass[0].as(); + auto phase_num = phase_pass[0].as(); ICHECK(phase_num) - << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; + << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer, " + << "but instead received " << phase_pass[0] << " with type " << phase_pass[0]->GetTypeKey(); int phase_num_val = phase_num->value; CHECK_GE(phase_num_val, 0); diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index f197ac4416fa..08e7ffc5bf59 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -31,6 +31,91 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } +namespace { + +/* \brief Normalize attributes from runtime types to Relax IR types + * + * While conversion from `tvm::runtime` types to compile-time IR + * types usually occurs as part of FFI conversions, the attributes + * are not converted, as they are stored in a `Map`. While this is required to allow attribute values to + * contain `ObjectRef` instances that are not IR expressions, the + * conversion should still be applied when possible. + * + * \param obj The IR attribute value to be normalized + * + * \return The normalized attribute value + */ +ObjectRef NormalizeAttr(ObjectRef obj) { + if (auto dict_attrs = obj.as()) { + auto new_dict = Downcast>(NormalizeAttr(dict_attrs->dict)); + if (new_dict.same_as(dict_attrs->dict)) { + return obj; + } else { + return DictAttrs(new_dict); + } + } else if (auto runtime_bool = obj.as()) { + return Bool(runtime_bool->value); + } else if (auto runtime_int = obj.as()) { + return Integer(runtime_int->value); + } else if (auto opt_array = obj.as>()) { + return opt_array.value().Map([](const ObjectRef& inner) { return NormalizeAttr(inner); }); + } else if (auto opt_map = obj.as>()) { + auto map = opt_map.value(); + + Map updates; + for (const auto& [key, inner] : map) { + auto new_inner = NormalizeAttr(inner); + if (!new_inner.same_as(inner)) { + updates.Set(key, new_inner); + } + } + for (const auto& [key, new_inner] : updates) { + map.Set(key, new_inner); + } + + return map; + + } else { + return obj; + } +} +} // namespace + +DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { + if (new_attrs.empty()) { + return attrs; + } + + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + + for (const auto& [key, value] : new_attrs) { + attr_dict.Set(key, NormalizeAttr(value)); + } + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + +DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value) { + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + attr_dict.Set(key, NormalizeAttr(value)); + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + +DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + attr_dict.erase(key); + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; @@ -43,11 +128,15 @@ void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_un dict.Set(key, val.operator PrimExpr()); } } + + dict = Downcast>(NormalizeAttr(dict)); } Array DictAttrsNode::ListFieldInfo() const { return {}; } DictAttrs::DictAttrs(Map dict) { + dict = Downcast>(NormalizeAttr(dict)); + ObjectPtr n = make_object(); n->dict = std::move(dict); data_ = std::move(n); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 596805f74b24..ded046eafc5d 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -47,6 +47,12 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { if (auto opt = ref.as()) { return tir::StringImm(opt.value()); } + if (auto opt = ref.as()) { + return Bool(opt.value()); + } + if (auto opt = ref.as()) { + return Integer(opt.value()); + } if (const auto* buffer_region = ref.as()) { Array indices; indices.reserve(buffer_region->region.size()); @@ -155,9 +161,14 @@ Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { TVM_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); -TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Range(args[0], args[1], args[2]); -}); +TVM_REGISTER_GLOBAL("ir.Range") + .set_body_typed([](PrimExpr begin, Optional end, Span span) -> Range { + if (end.defined()) { + return Range(begin, end.value(), span); + } else { + return Range(IntImm(begin->dtype, 0), begin, span); + } + }); TVM_REGISTER_NODE_TYPE(RangeNode); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index dc67822411c5..f0b879acbc03 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -107,43 +107,42 @@ bool PassContext::PassEnabled(const PassInfo& info) const { class PassConfigManager { public: - void Register(std::string key, uint32_t value_type_index) { + void Register(std::string key, uint32_t value_type_index, + std::function legalization) { ICHECK_EQ(key2vtype_.count(key), 0U); ValueTypeInfo info; info.type_index = value_type_index; info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + info.legalization = legalization; key2vtype_[key] = info; } // Trying to validate and legalize a config. void Legalize(Map* config) { std::vector> update; - auto* reflection = ReflectionVTable::Global(); - - for (auto kv : *config) { - auto it = key2vtype_.find(kv.first); + for (auto [key, obj] : *config) { + auto it = key2vtype_.find(key); if (it == key2vtype_.end()) { std::ostringstream os; - os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:"; + os << "AttributeError: Invalid config option \'" << key << "\' candidates are:"; int counter = 0; - for (const auto& kv : key2vtype_) { + for (const auto& [key, obj] : key2vtype_) { os << ' '; if (counter++ != 0) os << ','; - os << kv.first; + os << key; } LOG(FATAL) << os.str(); } const auto& info = it->second; - ICHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None"; - if (kv.second->IsInstance::ContainerType>()) { - ObjectRef converted = - reflection->CreateObject(info.type_key, Downcast>(kv.second)); - update.emplace_back(kv.first, converted); - } else { - if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { - LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " - << info.type_key << " but get " << kv.second->GetTypeKey(); - } + + ICHECK(obj.defined()) << "AttributeError: " << key << " is None"; + + ICHECK(info.legalization) << "AttributeError: " + << "Config option \'" << key + << "\' was defined without a legalization function."; + auto legalized = info.legalization(obj); + if (!legalized.same_as(obj)) { + update.emplace_back(key, legalized); } } for (auto&& kv : update) { @@ -170,13 +169,15 @@ class PassConfigManager { struct ValueTypeInfo { std::string type_key; uint32_t type_index; + std::function legalization; }; std::unordered_map key2vtype_; }; -void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { - PassConfigManager::Global()->Register(key, value_type_index); +void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index, + std::function legalization) { + PassConfigManager::Global()->Register(key, value_type_index, legalization); } Map> PassContext::ListConfigs() { diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index 416753871244..ce025540e496 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -39,8 +39,14 @@ void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { } else { os << int_imm->value; } + } else if (const auto* runtime_bool = json_obj.as()) { + os << (runtime_bool->value ? "true" : "false"); + } else if (const auto* runtime_int = json_obj.as()) { + os << runtime_int->value; } else if (const auto* float_imm = json_obj.as()) { os << std::setprecision(20) << float_imm->value; + } else if (const auto* runtime_float = json_obj.as()) { + os << std::setprecision(20) << runtime_float->value; } else if (const auto* str = json_obj.as()) { os << '"' << support::StrEscape(str->data, str->size) << '"'; } else if (const auto* array = json_obj.as()) { @@ -165,7 +171,7 @@ class JSONTokenizer { std::string to_parse(st, cur_); if (!is_float) { try { - *token = Token{TokenType::kInteger, IntImm(DataType::Int(64), std::stoll(to_parse))}; + *token = Token{TokenType::kInteger, runtime::Int(std::stoll(to_parse))}; } catch (const std::invalid_argument& e) { LOG(WARNING) << "ValueError: Invalid argument to std::stoll: " << to_parse << ". Details: " << e.what() << ". Switching to std::stod now."; @@ -178,7 +184,7 @@ class JSONTokenizer { } if (is_float) { try { - *token = Token{TokenType::kFloat, FloatImm(DataType::Float(64), std::stod(to_parse))}; + *token = Token{TokenType::kFloat, runtime::Float(std::stod(to_parse))}; } catch (const std::invalid_argument& e) { LOG(INFO) << "ValueError: Invalid argument to std::stod: " << to_parse << ". Details: " << e.what(); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 53f680f0a666..63af4a684567 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -192,7 +192,9 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, try { const ArrayNode* arr = json_obj.as(); ICHECK_EQ(arr->size(), 2); - workload = workloads[Downcast(arr->at(0)).IntValue()]; + int64_t workload_index = Downcast(arr->at(0)); + ICHECK(workload_index >= 0 && static_cast(workload_index) < workloads.size()); + workload = workloads[workload_index]; records[task_id] = TuningRecord::FromJSON(arr->at(1), workload); } catch (std::runtime_error& e) { LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index f5d89a85092b..5b3e6d251d56 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -137,7 +137,7 @@ std::vector MutateThreadBindingNode::FindCan ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; - int decision = Downcast(trace->decisions[GetRef(sample_inst)])->value; + int decision = Downcast(trace->decisions[GetRef(sample_inst)]); std::vector probs = support::AsVector(Downcast>(sample_inst->attrs[1])); diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index ea4e81c16f0c..a78b829e34ab 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -129,13 +129,13 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].get())) { ICHECK_EQ(inst->attrs.size(), 2); - std::vector probs = - support::AsVector(Downcast>(inst->attrs[1])); + std::vector probs = support::AsVector( + Downcast>(inst->attrs[1])); if (probs.size() == 1) { // Skip mutating the sampling instructions who have only single candidate. continue; } - const auto* d = TVM_TYPE_AS(decision, IntImmNode); + const auto* d = TVM_TYPE_AS(decision, runtime::Int::ContainerType); instructions.push_back(inst); decisions.push_back(d->value); } diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 7bbf00343af3..36dc57d80e66 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -114,9 +114,9 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, ICHECK_EQ(sample_inst->attrs.size(), 2); candidate->inst = GetRef(sample_inst); candidate->decision = - Downcast(trace->decisions[GetRef(sample_inst)])->value; - candidate->probs = - support::AsVector(Downcast>(sample_inst->attrs[1])); + Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->probs = support::AsVector( + Downcast>(sample_inst->attrs[1])); return true; } diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index b651b1f401cb..110cae96cb53 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -34,11 +34,11 @@ using namespace tvm::tir; std::function MakeFactorSampler(Schedule sch, Array thread_extents) { return [sch = std::move(sch), thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV { - Array extents; + Array extents; extents.reserve(thread_extents.size()); for (const Integer extent : thread_extents) { if (extent->value <= max_extent) { - extents.push_back(extent); + extents.push_back(runtime::Int(extent->value)); } } int n = extents.size(); @@ -48,7 +48,7 @@ std::function MakeFactorSampler(Schedule sch, Array th if (n == 1) { return Integer(extents[0]); } - Array probs(n, FloatImm(DataType::Float(64), 1.0 / n)); + Array probs(n, runtime::Float(1.0 / n)); return sch->SampleCategorical(extents, probs); }; } diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index e8d821636fd3..4a304cefa6bb 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -73,7 +73,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 3. Try block fusion. int n_candidate = static_cast(thread_extents.size()); - Array probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate)); + Array probs(n_candidate, 1.0 / n_candidate); tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { ICHECK(target_block.defined()); @@ -267,7 +267,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { /*! \brief The number of threads per warp */ int warp_size; /*! \brief Candidates of thread axis extent (values are required to be positive). */ - Array thread_extents; + Array thread_extents; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("max_threads_per_block", &max_threads_per_block); @@ -279,8 +279,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); }; -ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { - for (const Integer& extent : thread_extents) { +ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { + for (const auto& extent : thread_extents) { CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive"; } ObjectPtr n = make_object(); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index bcaf4343e256..2979e4229bdd 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -383,9 +383,8 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, if (!valid_vector_lens.empty()) { int n = valid_vector_lens.size(); double prob = 1.0 / n; - tir::ExprRV vector_load_len = - (*sch)->SampleCategorical(support::AsArray(valid_vector_lens), - Array(n, FloatImm(DataType::Float(64), prob))); + tir::ExprRV vector_load_len = (*sch)->SampleCategorical( + support::AsArray(valid_vector_lens), Array(n, prob)); (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); } } diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 045aa85b73ad..8ea2c2d1c6c3 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -68,7 +68,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) { int n = unroll_max_steps.size(); double prob = 1.0 / n; - Array probs(n, FloatImm(DataType::Float(64), prob)); + Array probs(n, runtime::Float(prob)); PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs); if (unroll_explicit) { sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step); @@ -102,7 +102,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { * \brief The options of the maximum number of unroll steps to be done. * Use an empty array to disable unroll. */ - Array unroll_max_steps; + Array unroll_max_steps; /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */ bool unroll_explicit; /*! \brief The number of maximum available jobs in CPU. */ @@ -122,7 +122,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, int max_vectorize_extent, - Array unroll_max_steps, + Array unroll_max_steps, bool unroll_explicit) { ObjectPtr n = make_object(); n->max_jobs_per_core = max_jobs_per_core; diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 3be264332461..83f5d073cb32 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -79,7 +79,7 @@ Array ScheduleRule::DefaultLLVM() { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; @@ -126,7 +126,7 @@ Array ScheduleRule::DefaultX86(const String& type) { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; @@ -158,11 +158,11 @@ Array ScheduleRule::DefaultCUDA() { /*require_ordered=*/false, /*disallow_op=*/Array{}), ScheduleRule::CrossThreadReduction( - /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), + /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/-1, /*max_vectorize_extent=*/-1, - /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, + /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, /*unroll_explicit=*/true), ScheduleRule::AutoBind( /*max_threadblocks=*/256, @@ -297,7 +297,7 @@ Array ScheduleRule::DefaultHexagon() { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/128, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), }; } @@ -410,7 +410,7 @@ Array ScheduleRule::DefaultARM(const String& type) { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/8, /*max_vectorize_extent=*/32, - /*unroll_max_steps=*/Array{0, 8, 32, 256}, + /*unroll_max_steps=*/Array{0, 8, 32, 256}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation()); } diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index ceb0356cbcfe..28c45ea7455d 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -424,13 +424,22 @@ inline Array AsFloatArray(const ObjectRef& obj) { Array results; results.reserve(arr->size()); for (const ObjectRef& elem : *arr) { - if (const auto* int_imm = elem.as()) { - results.push_back(FloatImm(DataType::Float(32), int_imm->value)); - } else if (const auto* float_imm = elem.as()) { - results.push_back(FloatImm(DataType::Float(32), float_imm->value)); - } else { - LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << elem->GetTypeKey(); - } + auto float_value = [&]() -> double { + if (const auto* int_imm = elem.as()) { + return int_imm->value; + } else if (const auto* runtime_int = elem.as()) { + return runtime_int->value; + } else if (const auto* float_imm = elem.as()) { + return float_imm->value; + } else if (const auto* runtime_float = elem.as()) { + return runtime_float->value; + } else { + LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " + << elem->GetTypeKey(); + } + }(); + + results.push_back(FloatImm(DataType::Float(32), float_value)); } return results; } @@ -446,11 +455,16 @@ inline Array AsIntArray(const ObjectRef& obj) { Array results; results.reserve(arr->size()); for (const ObjectRef& elem : *arr) { - if (const auto* int_imm = elem.as()) { - results.push_back(Integer(int_imm->value)); - } else { - LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); - } + auto int_value = [&]() -> int64_t { + if (const auto* int_imm = elem.as()) { + return int_imm->value; + } else if (const auto* runtime_int = elem.as()) { + return runtime_int->value; + } else { + LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); + } + }(); + results.push_back(Integer(int_value)); } return results; } diff --git a/src/node/boxed_primitive.cc b/src/node/boxed_primitive.cc new file mode 100644 index 000000000000..86596fb5ce29 --- /dev/null +++ b/src/node/boxed_primitive.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file node/boxed_primitive.cc + * + * \brief Reflection utilities for runtime-supported classes + * + * The fundamental support for boxing and unboxing of primitives + * during FFI calls is implemented in runtime/boxed_primitive.cc. In + * addition, boxed primitives may be registered with compile-time + * utilities (e.g. reflection, JSON import/export) that can provide + * additional functionality and improved debugging ability. However, + * neither these compile-time utilities nor any registration of + * `Box` into the compile-time utilities should be included as + * part of `libtvm_runtime.so`. + * + * This file contains the registration of the `libtvm_runtime.so` + * class `Box` for utilities that are contained in `libtvm.so`. + */ +#include +#include +#include +#include + +namespace tvm { +namespace runtime_ext { + +using runtime::Box; +using runtime::BoxNode; + +/* \brief Compile-time extension trait for runtime types + * + * Extends the use of boxed primitive during TVM's compilation step. + * + * Most TVM classes define these functions as part of the class + * definition. However, the boxed primitives must be usable at + * runtime, and so the class definition may only refer to types that + * are present in `libtvm_runtime.so`. + */ +template +struct BoxNodeCompileTimeTraits { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const BoxNode* node, SHashReducer hash_reduce) { + hash_reduce(node->value); + } + + static bool SEqualReduce(const BoxNode* lhs, const BoxNode* rhs, + SEqualReducer equal) { + return equal(lhs->value, rhs->value); + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + int64_t value = std::atoll(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + int64_t value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + if (blob == "true") { + return make_object>(true); + } else if (blob == "false") { + return make_object>(false); + } else { + LOG(FATAL) << "Invalid string '" << blob << "' for boolean"; + } + }) + .set_repr_bytes([](const Object* n) -> std::string { + bool value = GetRef(n).as>().value()->value; + if (value) { + return "true"; + } else { + return "false"; + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << (box->value ? "true" : "false") << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + double value = std::atof(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + double value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +} // namespace runtime_ext + +} // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 6e7d82ee4a59..b8918b4ea48c 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -57,7 +57,7 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->binding_names.push_back(Downcast(v)); } if (auto v = config_dict.Get("show_meta")) { - n->show_meta = Downcast(v)->value; + n->show_meta = Downcast(v)->value; } if (auto v = config_dict.Get("ir_prefix")) { n->ir_prefix = Downcast(v); @@ -81,16 +81,16 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->float_dtype = DataType(runtime::String2DLDataType(Downcast(v))); } if (auto v = config_dict.Get("verbose_expr")) { - n->verbose_expr = Downcast(v)->value; + n->verbose_expr = Downcast(v)->value; } if (auto v = config_dict.Get("indent_spaces")) { - n->indent_spaces = Downcast(v)->value; + n->indent_spaces = Downcast(v)->value; } if (auto v = config_dict.Get("print_line_numbers")) { - n->print_line_numbers = Downcast(v)->value; + n->print_line_numbers = Downcast(v)->value; } if (auto v = config_dict.Get("num_context_lines")) { - n->num_context_lines = Downcast(v)->value; + n->num_context_lines = Downcast(v)->value; } if (auto v = config_dict.Get("path_to_underline")) { n->path_to_underline = Downcast>>(v).value_or(Array()); @@ -107,13 +107,13 @@ PrinterConfig::PrinterConfig(Map config_dict) { Downcast>>(v).value_or(Map()); } if (auto v = config_dict.Get("syntax_sugar")) { - n->syntax_sugar = Downcast(v)->value; + n->syntax_sugar = Downcast(v)->value; } if (auto v = config_dict.Get("show_object_address")) { - n->show_object_address = Downcast(v)->value; + n->show_object_address = Downcast(v)->value; } if (auto v = config_dict.Get("show_all_struct_info")) { - n->show_all_struct_info = Downcast(v)->value; + n->show_all_struct_info = Downcast(v)->value; } // Checking prefixes if they are valid Python identifiers. diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 379a75f6109b..614669a412d0 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -65,6 +65,22 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, return fsequal_reduce_[tindex](self, other, equal); } +namespace { +ObjectPath GetAttrPath(const ObjectRef& obj, const void* attr_address, const ObjectPath& path) { + if (obj->IsInstance() || + obj->IsInstance() || + obj->IsInstance()) { + // Special case for containers that contain boxed primitives. The + // "value" attribute containing the boxed value should not be part + // of the reported mismatched path. + return path; + } else { + Optional attr_key = GetAttrKeyByAddress(obj.get(), attr_address); + return path->Attr(attr_key); + } +} +} // namespace + struct SEqualReducer::PathTracingData { ObjectPathPair current_paths; ObjectRef lhs_object; @@ -72,10 +88,9 @@ struct SEqualReducer::PathTracingData { Optional* first_mismatch; ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const { - Optional lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs); - Optional rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs); - return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key), - current_paths->rhs_path->Attr(rhs_attr_key)); + ObjectPath lhs_attr_path = GetAttrPath(lhs_object, &lhs, current_paths->lhs_path); + ObjectPath rhs_attr_path = GetAttrPath(rhs_object, &rhs, current_paths->rhs_path); + return ObjectPathPair(lhs_attr_path, rhs_attr_path); } }; @@ -98,13 +113,12 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { /* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch( const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) { if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) { - Optional lhs_attr_key = - GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address); - Optional rhs_attr_key = - GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address); - *tracing_data->first_mismatch = - ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key), - tracing_data->current_paths->rhs_path->Attr(rhs_attr_key)); + ObjectPath lhs_attr_path = + GetAttrPath(tracing_data->lhs_object, lhs_address, tracing_data->current_paths->lhs_path); + ObjectPath rhs_attr_path = + GetAttrPath(tracing_data->rhs_object, rhs_address, tracing_data->current_paths->rhs_path); + + *tracing_data->first_mismatch = ObjectPathPair(lhs_attr_path, rhs_attr_path); } } @@ -200,7 +214,6 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, } // Slow path: tracing object paths for better error reporting - ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths; if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) { diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 334e6e5c9a62..1c795594629e 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -45,6 +45,7 @@ using namespace relax; using namespace tvm::runtime; using namespace tvm::runtime::relax_vm; +namespace { // Helper function to get the function name of the registered packed function implementation of // relax operator. FCallPacked GetPackedFuncName(const Call& call) { @@ -57,6 +58,7 @@ FCallPacked GetPackedFuncName(const Call& call) { } return {}; } +} // namespace /*! * \brief A class to generate VM executable for Relax functions. diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index dd34bc63bb31..5e6a1c3f8442 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -44,6 +44,21 @@ namespace relax_vm { using vm::VMFuncInfo; +namespace { +// Helper function to get the function name of the registered packed function implementation of +// relax operator. +FCallPacked GetPackedFuncName(const Call& call) { + static auto op_map = Op::GetAttrMap("FCallPacked"); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op]; + } + } + return {}; +} +} // namespace + /*! * \brief A class to generate VMTIR for Relax functions. * @@ -232,7 +247,14 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); if (call->op.as()) { - if (call_node->op == call_builtin_with_ctx_op_) { + // special case generate for the intrinsics whose attribute fields + // cannot be represented by args in the CallNode + FCallPacked name = GetPackedFuncName(call); + if (name.size()) { + // If the operator has a registered packed function implementation, emit call to that packed + // function. + EmitCallPacked(name, VisitArray(call->args), dst_reg); + } else if (call_node->op == call_builtin_with_ctx_op_) { EmitCallBuiltinWithCtx(call, dst_reg); } else if (call_node->op == alloc_storage_op_) { EmitAllocStorage(call, dst_reg); @@ -260,10 +282,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); - // turn ndarray cond value into scalar. - cond_value = tir::Cast(DataType::Bool(), - tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), - {tir::StringImm("vm.builtin.read_if_cond"), cond_value})); + cond_value = tir::Call(DataType::Bool(), tir::builtin::tvm_call_packed(), + {tir::StringImm("vm.builtin.read_if_cond"), cond_value}); tir::Stmt true_branch = WithNewScope([&]() { PrimExpr true_value = this->VisitExpr(op->true_branch).value(); diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index fd6fea6e703c..7aca1470aee4 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -36,7 +36,7 @@ namespace relax { TVM_REGISTER_NODE_TYPE(InitAttrs); /* relax.full */ -Expr full(ObjectRef shape, Expr fill_value, DataType dtype) { +Expr full(Variant> shape, Expr fill_value, DataType dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { shape_in_expr = GetRef(expr); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 989eaa12fdbf..6e7c8255238a 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -39,7 +39,7 @@ namespace relax { * If dtype is not given, it will by default use the dtype of fill_value. * \return The result tensor. */ -Expr full(ObjectRef shape, Expr fill_value, DataType dtype); +Expr full(Variant> shape, Expr fill_value, DataType dtype); /*! * \brief Construct a tensor such that diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 07c90756bf90..2b1c6eafb652 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -654,7 +654,7 @@ TVM_REGISTER_OP("relax.permute_dims") .set_attr("FPurity", Bool(true)); /* relax.reshape */ -Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { +Expr ConvertNewShapeToExpr(const Expr& data, const Variant>& shape) { const ArrayNode* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as()) { @@ -747,7 +747,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { return ShapeExpr(array_ref); } -Expr reshape(Expr x, ObjectRef shape) { +Expr reshape(Expr x, Variant> shape) { Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); static const Op& op = Op::Get("relax.reshape"); return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); @@ -812,7 +812,7 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); -Expr split(Expr x, ObjectRef indices_or_sections, int axis) { +Expr split(Expr x, Variant> indices_or_sections, int axis) { ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 32aa10776894..68622f1359e0 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -90,7 +90,7 @@ Expr permute_dims(Expr x, Optional> axes); * It is required to be either an Array of PrimExpr, or a Shape in Relax * \return The reshaped result. */ -Expr reshape(Expr x, ObjectRef shape); +Expr reshape(Expr x, Variant> shape); /*! * \brief Split input tensor along axis by sections or indices. @@ -105,7 +105,7 @@ Expr reshape(Expr x, ObjectRef shape); * \param axis The axis over which to split. * \return The computed result. */ -Expr split(Expr x, ObjectRef indices_or_sections, int axis); +Expr split(Expr x, Variant> indices_or_sections, int axis); /*! * \brief Squeeze axes in the array. diff --git a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc index 61b6c9ce897f..345e2d0e60da 100644 --- a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc +++ b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc @@ -40,7 +40,7 @@ Target CreateTarget(const tvm::transform::PassContext& ctx) { String mcpu = cfg.value()->mcpu; Array mattr = {cfg.value()->mattr}; - Bool debug_last_error = cfg.value()->debug_last_error; + runtime::Bool debug_last_error = cfg.value()->debug_last_error->value; Target cmsis_nn_target(TargetJSON{ {"kind", String("cmsis-nn")}, diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index 10125bf814ad..00581a089a4a 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -37,7 +37,7 @@ using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc>("mattr") .add_attr_option("mcpu") - .add_attr_option("debug_last_error") + .add_attr_option("debug_last_error") .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); diff --git a/src/relay/backend/contrib/cutlass/target.cc b/src/relay/backend/contrib/cutlass/target.cc index 50c8b84a9069..ea040f6ff56a 100644 --- a/src/relay/backend/contrib/cutlass/target.cc +++ b/src/relay/backend/contrib/cutlass/target.cc @@ -39,32 +39,32 @@ namespace cutlass { * src/relay/backend/contrib/cutlass/codegen.cc */ TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)) .set_attr("RelayToTIR", CompileForCutlass()) // An integer specifying the compute capability. For example, 75 for Turing and // 80 or 86 for Ampere. - .add_attr_option("sm", Integer(80)) + .add_attr_option("sm", runtime::Int(80)) // Whether to use slower but very accurate (compared to tf32) 3xtf32 mode for // fp32 inputs on tensorcore. - .add_attr_option("use_3xtf32", Bool(true)) + .add_attr_option("use_3xtf32", runtime::Bool(true)) // Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in // parallel across split-K blocks, and a separate global reduction kernel is launched to // accumulate partial reductions. The profiler will pick the best split-k factor from the // given candidate list. Note that the larger split-K factor requires a larger workspace. // Currently, parallel split-k has been tested only for wgrad. For GEMM and other conv2d // kinds, split_k_slices is ignored. - .add_attr_option>("split_k_slices", Array({1})) + .add_attr_option>("split_k_slices", Array{runtime::Int(1)}) // When True, profile all kernel variants with smaller alignments than the largest possible. - .add_attr_option("profile_all_alignments", Bool(false)) + .add_attr_option("profile_all_alignments", runtime::Bool(false)) // Whether to profile all candidate kernels, or stop profiling after the first applicable kernel // is found. - .add_attr_option("find_first_valid", Bool(false)) + .add_attr_option("find_first_valid", runtime::Bool(false)) // Whether to compile profiler executables for different kernels in parallel. - .add_attr_option("use_multiprocessing", Bool(false)) + .add_attr_option("use_multiprocessing", runtime::Bool(false)) // Number of threads to use during compilation, or -1 to use number of cpus. - .add_attr_option("threads", Integer(-1)) + .add_attr_option("threads", runtime::Int(-1)) // Whether to replace sigmoid with tanh. - .add_attr_option("use_fast_math", Bool(false)) + .add_attr_option("use_fast_math", runtime::Bool(false)) // A temporary directory where intermediate compiled artifacts will be stored. .add_attr_option("tmp_dir", String("./tmp")); diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.cc b/src/relay/backend/contrib/ethosn/ethosn_api.cc index a3f3e6e1eb6e..0f539d96e919 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api.cc +++ b/src/relay/backend/contrib/ethosn/ethosn_api.cc @@ -687,14 +687,14 @@ EthosnError EthosnAPI::Split(const Expr& expr, SplitParams* params) { sl::TensorInfo(input_tensor_shape, input_data_type, params->input_info.m_DataFormat, params->input_info.m_QuantizationInfo); params->split_info.m_Axis = attrs->axis; - if (attrs->indices_or_sections->IsInstance()) { - auto sections = Downcast(attrs->indices_or_sections)->value; + if (const auto* sections_ptr = attrs->indices_or_sections.as()) { + auto sections = sections_ptr->value; int size = input_tensor_shape[attrs->axis] / sections; for (int i = 0; i < sections; i++) { params->split_info.m_Sizes.push_back(size); } } else { - auto indices = Downcast>(attrs->indices_or_sections); + auto indices = Downcast>(attrs->indices_or_sections); int last_index = 0; for (const auto& i : indices) { params->split_info.m_Sizes.push_back(i->value - last_index); diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index 54d0595c4634..300372838416 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -307,8 +307,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { Array compile_artifacts; for (const auto& kv : mod->functions) { const tir::PrimFunc& prim_func = Downcast(kv.second); - Optional> params = - prim_func->GetAttr>("ethos-u.constants"); + auto params = prim_func->GetAttr>("ethos-u.constants"); ICHECK(params) << "microNPU params should be present"; auto primfunc_to_artifact_pf = tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact"); diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index 23a873b2d392..d87447f863e2 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -97,7 +97,7 @@ class ExternalFuncIOHandler : public ExprRewriter { Expr CreateSplitReshapedTensors(const Expr& input, const Array& original_args) { Array> shapes; Array flatten_tensor_sizes; - Array split_indices; + Array split_indices; Array rets; int total_size = 0; @@ -132,7 +132,7 @@ class ExternalFuncIOHandler : public ExprRewriter { if (func->params.size() > 1) { Array> shapes; Array flatten_tensor_sizes; - Array split_indices; + Array split_indices; auto func_name = gv->name_hint; int total_size = 0; diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index b45987f6be33..de9c81a2706e 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -38,6 +38,6 @@ TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) .set_attr(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR()) .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime) - .add_attr_option("example_attribute", Integer(0)); + .add_attr_option("example_attribute", Integer(0)); } // namespace tvm diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index f4babad50a3e..1dd5e3a4d772 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -177,12 +177,12 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { std::vector indices_or_sections; std::vector mode; std::vector axis = {std::to_string(split_attr->axis)}; - if (const auto* sections = split_attr->indices_or_sections.as()) { + if (const auto* sections = split_attr->indices_or_sections.as()) { mode.emplace_back("sections"); indices_or_sections.emplace_back(std::to_string(sections->value)); } else { mode.emplace_back("indices"); - auto indices = Downcast>(split_attr->indices_or_sections); + auto indices = Downcast>(split_attr->indices_or_sections); for (const auto& i : indices) { indices_or_sections.emplace_back(std::to_string(i->value)); } diff --git a/src/relay/backend/contrib/tensorrt/target.cc b/src/relay/backend/contrib/tensorrt/target.cc index 0277787a8c12..a62dc25e329c 100644 --- a/src/relay/backend/contrib/tensorrt/target.cc +++ b/src/relay/backend/contrib/tensorrt/target.cc @@ -38,30 +38,30 @@ namespace tensorrt { * - Runtime: src/runtime/contrib/tensorrt/... */ TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)) .set_attr("RelayToTIR", CompileForTensorRT()) // A array of three integers given the major, minor, and patch numbers for the supported // TensorRT compiler version. If empty will be auto-detected from linked library. Default empty. - .add_attr_option>("tensorrt_version", Array()) + .add_attr_option>("tensorrt_version", Array()) // If true, the first tensor dimension for most operators is allowed to be Any and // TensorRT will assume it represents a batch dimension only known at inference time. // Fewer Relay operators are supported in implicit batch mode. Default true. - .add_attr_option("use_implicit_batch", Bool(true)) + .add_attr_option("use_implicit_batch", runtime::Bool(true)) // If true, excludes sub-graphs which do not have multiply-accumulate operations, even though // TensorRT supports them. ad. This is a simple heuristic to optimize the partitioning between // TensorRT and TVM. Not required if using Collage for partitioning. Defalut false. - .add_attr_option("remove_no_mac_subgraphs", Bool(false)) + .add_attr_option("remove_no_mac_subgraphs", runtime::Bool(false)) // How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation. // Default 1G. - .add_attr_option("max_workspace_size", Integer(1 << 30)) + .add_attr_option("max_workspace_size", runtime::Int(1 << 30)) // If true, allows TensorRT to automatically convert float32 operations to float16. Must also be // enabled if any float16 operations are in the model. Note that TensorRT may still choose a // higher-precision kernel if it results in overall lower runtime, or if no low-precision // implementation exists. Default false. - .add_attr_option("use_fp16", Bool(false)) + .add_attr_option("use_fp16", runtime::Bool(false)) // If true, allows TensorRT to automatically convert float32 operations to uint8 // (aka quantized). Default false. - .add_attr_option("use_uint8", Bool(false)); + .add_attr_option("use_uint8", runtime::Bool(false)); } // namespace tensorrt } // namespace contrib diff --git a/src/relay/backend/contrib/uma/targets.cc b/src/relay/backend/contrib/uma/targets.cc index 244f243749c1..0499c0bba198 100644 --- a/src/relay/backend/contrib/uma/targets.cc +++ b/src/relay/backend/contrib/uma/targets.cc @@ -58,7 +58,7 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") .add_attr_option("model") .add_attr_option>("libs") .add_attr_option("host") - .add_attr_option("from_device") + .add_attr_option("from_device") .set_attr( attr::kRelayToTIR, relay::contrib::uma::RelayToTIR(target_name)) .set_attr("TIRToRuntime", relay::contrib::uma::TIRToRuntime); @@ -75,8 +75,9 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") } if (default_value->IsInstance()) { target_kind.add_attr_option(option_name, Downcast(default_value)); - } else if (default_value->IsInstance()) { - target_kind.add_attr_option(option_name, Downcast(default_value)); + } else if (default_value->IsInstance()) { + target_kind.add_attr_option(option_name, + Downcast(default_value)); } else { LOG(FATAL) << "TypeError: Only String, Integer, or Bool are supported. " << "Given attribute option type: " << attr_option.second->GetTypeKey(); diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc index 1d6caecb87ba..66feac4699e6 100644 --- a/src/relay/backend/executor.cc +++ b/src/relay/backend/executor.cc @@ -89,13 +89,13 @@ ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { /********** Register Executors and options **********/ TVM_REGISTER_EXECUTOR("aot") - .add_attr_option("link-params", Bool(true)) - .add_attr_option("unpacked-api") + .add_attr_option("link-params", runtime::Bool(true)) + .add_attr_option("unpacked-api") .add_attr_option("interface-api") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constant-byte-alignment"); + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constant-byte-alignment"); -TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", Bool(false)); +TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", runtime::Bool(false)); /********** Registry **********/ diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc index 923c9b2d5f65..0534298ea44d 100644 --- a/src/relay/backend/runtime.cc +++ b/src/relay/backend/runtime.cc @@ -88,9 +88,9 @@ RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { /********** Register Runtimes and options **********/ -TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); -TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); /********** Registry **********/ diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 0c0ff7290115..3e86e1c8eaf9 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -73,6 +73,42 @@ bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& exp } bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { + // Unwrapping arrays may find user-provided FFI types in the + // attributes (e.g. Defining pad_value as ((0,0), (0,0)) will result + // in runtime::Int. These need to be converted to compile-time IR + // types when encountered. + if (lhs->IsInstance() || + lhs->IsInstance() || + lhs->IsInstance()) { + TVMRetValue lhs_convert; + lhs_convert = lhs; + PrimExpr lhs_expr = lhs_convert; + return MatchRetValue(lhs_expr, rhs); + } + + // StructuralEqual doesn't check for conversions between FFI types + // and IR types, but the pattern-matcher should. Therefore, + // explicitly recurse into the array. + if (auto opt_lhs_array = lhs.as>()) { + if (Optional> opt_rhs_array = rhs) { + Array lhs_array = opt_lhs_array.value(); + Array rhs_array = opt_rhs_array.value(); + if (lhs_array.size() != rhs_array.size()) { + return false; + } + for (size_t i = 0; i < lhs_array.size(); i++) { + TVMRetValue rhs_item; + rhs_item = rhs_array[i]; + if (!MatchRetValue(lhs_array[i], rhs_item)) { + return false; + } + } + return true; + } else { + return false; + } + } + switch (rhs.type_code()) { case kDLInt: if (auto* val = lhs.as()) { diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 50d8531c7dd0..222aba4bd25b 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -79,7 +79,7 @@ Expr MakeReshape(Expr data, Array newshape, bool allowzero = false); Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, Integer rhs_end); -Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); +Expr MakeSplit(Expr data, Variant> indices_or_sections, int axis); Expr MakeSqueeze(Expr data, Array axis); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index fde6daa4d851..96f833d80505 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2984,10 +2984,10 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, Layout ret = Layout::Undef(); size_t size = 0; - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { size = sections->value; } else { - size = Downcast>(param->indices_or_sections).size() + 1; + size = Downcast>(param->indices_or_sections).size() + 1; } // If new_in_layouts are defined, this code tries to modify the layout. @@ -2998,13 +2998,12 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, param->axis = new_index; int factor = new_in_layouts[0].FactorOf(sp_dim); if (factor > 1) { - if (!param->indices_or_sections.as()) { - auto ios = Downcast>(param->indices_or_sections); - Array new_ios; + if (!param->indices_or_sections.as()) { + auto ios = Downcast>(param->indices_or_sections); + Array new_ios; for (const auto& v : ios) { - const IntImmNode* vint = v.as(); - new_ios.push_back(vint->value / factor); - if (vint->value % factor) { + new_ios.push_back(runtime::Int(v->value / factor)); + if (v->value % factor) { divisible = false; } } @@ -3041,7 +3040,7 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range."; ICHECK_GE(axis, 0) << "axis should be within the input dimension range."; - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { if (!data->shape[axis].as()) { ICHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == tir::make_zero(DataType::Int(64)))) @@ -3061,8 +3060,8 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TupleType(Array(fields))); } else { Array indices; - for (auto i : Downcast>(param->indices_or_sections)) { - indices.push_back(IntImm(DataType::Int(32), i.as()->value)); + for (auto index : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), index->value)); } auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; @@ -3097,19 +3096,20 @@ Array SplitCompute(const Attrs& attrs, const Array& inpu const auto param = attrs.as(); ICHECK(param != nullptr); - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { int64_t num_sections = sections->value; return Array{topi::split_sections(inputs[0], num_sections, param->axis)}; } else { Array indices; - for (auto i : Downcast>(param->indices_or_sections)) { - indices.push_back(IntImm(DataType::Int(32), i.as()->value)); + for (auto index : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), index->value)); } return Array{topi::split(inputs[0], indices, param->axis)}; } } -Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { +Expr MakeSplit(Expr data, Variant> indices_or_sections, + int axis) { auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); @@ -3117,17 +3117,7 @@ Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) { - if (args.type_codes[1] == kDLInt) { - // Note: we change it from Int(64) to Int(32) for now as - // combine_parallel_dense will transform the graph with Int(32). - // More invetigation is needs to check which one we should use. - *rv = - MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast(args[1])), args[2]); - } else { - *rv = MakeSplit(args[0], args[1], args[2]); - } -}); +TVM_REGISTER_GLOBAL("relay.op._make.split").set_body_typed(MakeSplit); RELAY_REGISTER_OP("split") .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. @@ -4157,11 +4147,13 @@ bool ScanopRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) { +Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Optional exclusive) { auto attrs = make_object(); attrs->dtype = dtype; attrs->axis = axis; - attrs->exclusive = exclusive; + if (exclusive.defined()) { + attrs->exclusive = exclusive.value(); + } static const Op& op = Op::Get("cumsum"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index a41e1e0d6674..74827f166b51 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -159,7 +159,7 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int index = 0; - auto split = MakeSplit(data, Integer(branches.size()), 0); + auto split = MakeSplit(data, runtime::Int(branches.size()), 0); for (const auto& branch : branches) { auto split_data = TupleGetItem(split, index++); auto squeezed_data = MakeSqueeze(split_data, {0}); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 34f986b251a2..df28506c6217 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -266,7 +266,7 @@ class ConstantFolder : public MixedModeMutator { // always use graph executor with no link-params dict.Set(tvm::attr::kExecutor, - relay::Executor::Create("graph", {{"link-params", Bool(false)}})); + relay::Executor::Create("graph", {{"link-params", runtime::Bool(false)}})); Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), eval_cpu_dev_, eval_cpu_target_, dict)); VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index edf1e4c99f4d..da7a8f6420cd 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -36,8 +36,6 @@ namespace tvm { namespace relay { -using namespace tvm::runtime; - /*! What is automatic differentiation(AD) and why is it important? * By AD, we roughly mean, given a term which denotes some mathematical function, * derive a term which denotes the derivative of that mathematical function. diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 5026b1bcba79..1112755b76a0 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -66,7 +66,7 @@ using CachedCastNodes = std::unordered_map, // Return array is of type : [MixedTypeConversionCategory (int), String, String] // The fields are : [ConversionCategory, accumulation_datatype, output_datatype] // Call is a call node, DataType is the mixed precision type -using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc( +using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc>( const Call& call_node, const std::string& target_dtype_str)>; /*! \brief This class transforms the given relay module into a version where @@ -372,7 +372,7 @@ class MixedPrecisionPass : public MixedModeMutator { if (attr_map.count(op)) { // Calculate the conversion category and dtypes from registered attribute. FTVMMixedPrecisionConversionType func = attr_map[op]; - Array op_descriptor = + Array> op_descriptor = func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type_)); ICHECK(op_descriptor.size() == 3) << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size() diff --git a/src/runtime/boxed_primitive.cc b/src/runtime/boxed_primitive.cc new file mode 100644 index 000000000000..9ab83a7b471c --- /dev/null +++ b/src/runtime/boxed_primitive.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/boxed_primitive.cc + * \brief Implementations of ObjectRef wrapper. + */ + +#include +#include + +namespace tvm { +namespace runtime { + +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); + +/* \brief Allow explicit construction of Box + * + * Convert a `bool` to `Box`. For use in FFI handling, to + * provide an umambiguous representation between `bool(true)` and + * `int(1)`. Will be automatically unboxed in the case where a + * `Box` is provided to a PackedFunc that requires `int` input, + * mimicking C++'s default conversions. + * + * This is only needed for Box, as Box and Box + * can be converted in C++ as part of `TVMArgValue::operator + * ObjectRef()` without ambiguity, postponing conversions until + * required. + */ +TVM_REGISTER_GLOBAL("runtime.BoxBool").set_body_typed([](bool value) { return Box(value); }); + +/* \brief Return the underlying boolean object. + * + * Used while unboxing a boolean return value during FFI handling. + * The return type is intentionally `int` and not `bool`, to avoid + * recursive unwrapping of boolean values. + * + * This is only needed for Box, as Box and Box + * can be unambiguously unboxed as part of + * `TVMRetValue::operator=(ObjectRef)`. + */ +TVM_REGISTER_GLOBAL("runtime.UnBoxBool").set_body_typed([](Box obj) -> int { + return obj->value; +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 57979b160ea7..04d36ad8bcab 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -361,14 +361,18 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r TVMAPISetLastError("ModuleGetFunction expects second argument to be a string"); return kTvmErrorFunctionCallWrongArgType; } - if (type_codes[2] != kDLInt) { + + if (type_codes[2] == kDLInt) { + query_imports = args[2].v_int64 != 0; + } else if (type_codes[2] == kTVMArgBool) { + query_imports = args[2].v_bool; + } else { TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer"); return kTvmErrorFunctionCallWrongArgType; } mod = (TVMModuleHandle)args[0].v_handle; name = args[1].v_str; - query_imports = args[2].v_int64 != 0; to_return = TVMModGetFunction(mod, name, query_imports, &ret_value->v_handle); if (to_return == 0) { diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 493bc3fb1dc9..f7204e372f6d 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -102,10 +102,10 @@ DRef BcastSessionObj::CallWithPacked(const TVMArgs& args) { int cnt = 0; for (int i = 3; i < num_args; ++i) { int type_code = type_codes[i]; - if (type_code != kDLInt && type_code != kDLUInt && type_code != kDLFloat && - type_code != kTVMDataType && type_code != kDLDevice && type_code != kTVMOpaqueHandle && - type_code != kTVMStr && type_code != kTVMNullptr && type_code != kTVMBytes && - type_code != kTVMObjectHandle) { + if (type_code != kDLInt && type_code != kDLUInt && type_code != kTVMArgBool && + type_code != kDLFloat && type_code != kTVMDataType && type_code != kDLDevice && + type_code != kTVMOpaqueHandle && type_code != kTVMStr && type_code != kTVMNullptr && + type_code != kTVMBytes && type_code != kTVMObjectHandle) { os << "\n Argument #" << i - 3 << " has unsupported type code: " << type_code << " (" << ArgTypeCode2Str(type_code) << ")"; cnt += 1; diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index d08dadb02bb9..485ebdb449da 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -325,6 +325,10 @@ struct RPCReference { channel->template Write(value.v_int64); break; } + case kTVMArgBool: { + channel->template Write(value.v_bool); + break; + } case kTVMDataType: { channel->Write(value.v_type); // padding @@ -432,6 +436,10 @@ struct RPCReference { channel->template Read(&(value.v_int64)); break; } + case kTVMArgBool: { + channel->template Read(&(value.v_bool)); + break; + } case kTVMDataType: { channel->Read(&(value.v_type)); int32_t padding = 0; diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 2af31f1d4021..af1cf9d20335 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -279,7 +279,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo * \param err_ctx Additional context if error occurs. */ void CheckPrimValueInfo(TVMArgValue arg, DataType dtype, Optional err_ctx) { - if (dtype.is_bool()) { + if (arg.IsObjectRef()) { + ObjectRef obj = arg.AsObjectRef(); + LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", expected dtype " << dtype + << ", but received ObjectRef of type " << obj->GetTypeKey(); + } else if (dtype.is_bool()) { arg.operator bool(); } else if (dtype.is_int()) { arg.operator int64_t(); @@ -426,7 +430,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.to_device") * \return Bool */ bool ReadIfCond(TVMArgValue cond) { - if (cond.type_code() == kDLInt) return cond.operator bool(); + if (cond.type_code() == kDLInt || cond.type_code() == kTVMArgBool) { + return cond.operator bool(); + } NDArray arr = cond.operator tvm::runtime::NDArray(); if (arr->device.device_type != kDLCPU) { arr = arr.CopyTo(DLDevice{kDLCPU, 0}); diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 54194e7e2a41..61bdec680a29 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -323,12 +323,33 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } else if (const auto* float_imm = value.as()) { // TODO(yelite): Make float number printing roundtrippable - output_.precision(17); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { output_ << '"' << float_imm->value << '"'; + } else if (std::nearbyint(float_imm->value) == float_imm->value) { + // Special case for floating-point values which would be + // formatted using %g, are not displayed in scientific + // notation, and whose fractional part is zero. + // + // By default, using `operator<<(std::ostream&, double)` + // delegates to the %g printf formatter. This strips off any + // trailing zeros, and also strips the decimal point if no + // trailing zeros are found. When parsed in python, due to the + // missing decimal point, this would incorrectly convert a float + // to an integer. Providing the `std::showpoint` modifier + // instead delegates to the %#g printf formatter. On its own, + // this resolves the round-trip errors, but also prevents the + // trailing zeros from being stripped off. + std::showpoint(output_); + std::fixed(output_); + output_.precision(1); + output_ << float_imm->value; } else { + std::defaultfloat(output_); + std::noshowpoint(output_); + output_.precision(17); output_ << float_imm->value; } + } else if (const auto* string_obj = value.as()) { output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index ef68b89b5bf4..686f486da6eb 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -30,6 +30,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return LiteralDoc::Str(s, p); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Bool obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Boolean(obj->value, p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Int obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Int(obj->value, p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Float obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Float(obj->value, p); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // "", [](Array array, ObjectPath p, IRDocsifier d) -> Doc { diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 6f9a8cbf8918..35a9f35db491 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -75,7 +75,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases - return LiteralDoc::Int(n->value, n_p); + if (n->dtype.is_bool()) { + return LiteralDoc::Boolean(n->value, n_p); + } else { + return LiteralDoc::Int(n->value, n_p); + } }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/support/array.h b/src/support/array.h index 0ca57a2410c5..0d4c8134787b 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -164,12 +164,14 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { + inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : vec) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -177,12 +179,14 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { + inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : vec) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -191,11 +195,13 @@ struct AsVectorImpl { template struct AsVectorImpl { inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : array) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -221,8 +227,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (int x : vec) { - result.push_back(Integer(x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } @@ -233,8 +241,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (int64_t x : vec) { - result.push_back(Integer(x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } @@ -245,8 +255,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (double x : vec) { - result.push_back(FloatImm(tvm::DataType::Float(64), x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index aec57a1eb20d..928cdfcab80b 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -189,6 +189,58 @@ TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian TVM_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); +TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRef").set_body_typed([](ObjectRef arg) -> ObjectRef { + return arg; +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray") + .set_body_typed([](Array arg) -> ObjectRef { return arg[0]; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") + .set_body_typed([](Map map, ObjectRef key) -> ObjectRef { + return map[key]; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") + .set_body_typed([](Map map) -> ObjectRef { return map; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { + return expr; +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") + .set_body_typed([](Array arr) -> ObjectRef { + for (ObjectRef item : arr) { + CHECK(item->IsInstance()) + << "Array contained " << item->GetTypeKey() << " when it should contain PrimExpr"; + } + return arr; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") + .set_body_typed([](Array> arr) -> ObjectRef { + for (ObjectRef item : arr) { + CHECK(item->IsInstance() || item->IsInstance()) + << "Array contained " << item->GetTypeKey() + << " when it should contain either PrimExpr or PackedFunc"; + } + return arr; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") + .set_body_typed([](Map map) -> ObjectRef { + for (const auto& kv : map) { + ObjectRef value = kv.second; + CHECK(value->IsInstance()) + << "Map contained " << value->GetTypeKey() << " when it should contain PrimExpr"; + } + return map; + }); + /** * Simple event logger that can be used for testing purposes */ diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 481ba39cc7b1..21899a12c4b0 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -347,18 +347,26 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value } case builtin::kTVMValueContent: { ICHECK_EQ(t.lanes(), 1); - ICHECK(t.is_handle() || t.bits() == 64); - if (t.is_int()) { + if (t.is_bool()) { + // The stride between adjacent entries is still + // `sizeof(TVMValue)==64`, even if the enum currently holds a + // boolean. + buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_int64_, buf, index); + buf = builder_->CreatePointerCast(buf, DTypeToLLVMType(t)->getPointerTo()); + return TypedPointer(t_int8_, buf); + } else if (t.is_int() && t.bits() == 64) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); - } else if (t.is_float()) { + } else if (t.is_float() && t.bits() == 64) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); - } else { - ICHECK(t.is_handle()); + } else if (t.is_handle()) { buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); + } else { + LOG(DEBUG) << "DataType " << t << " cannot be stored into a TVMValue"; } } default: @@ -1366,9 +1374,16 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { return builder_->CreatePointerCast(ref.addr, t_void_p_); - } else { - return builder_->CreateLoad(ref.type, ref.addr); } + + llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr); + + if (op->dtype == DataType::Bool()) { + struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value); + } + + return struct_value; + } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index dd5a3fb681ee..0406dcf951bb 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -294,10 +294,10 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) target_options_.MCOptions.ABIName = Downcast(target.Get("mabi")); } - auto maybe_level = Downcast(target.Get("opt-level")); + auto maybe_level = target.Get("opt-level").as(); #if TVM_LLVM_VERSION <= 170 if (maybe_level.defined()) { - int level = maybe_level->value; + int level = maybe_level.value()->value; if (level <= 0) { opt_level_ = llvm::CodeGenOpt::None; } else if (level == 1) { @@ -313,7 +313,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } #else if (maybe_level.defined()) { - int level = maybe_level->value; + int level = maybe_level.value()->value; if (level <= 0) { opt_level_ = llvm::CodeGenOptLevel::None; } else if (level == 1) { @@ -333,8 +333,12 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) // Fast math options - auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool { - return Downcast(target.Get(flag.str()).value_or(Bool(false))); + auto GetBoolFlag = [&target](llvm::StringRef name) -> bool { + if (auto flag = target.Get(name.str())) { + return Downcast(flag); + } else { + return false; + } }; if (GetBoolFlag("fast-math")) { #if TVM_LLVM_VERSION >= 60 diff --git a/src/target/tag.cc b/src/target/tag.cc index 9eca3072df0e..d45bf61a38f1 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -76,61 +76,61 @@ TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}, + {"num-cores", runtime::Int(4)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}}}}); + {"num-cores", runtime::Int(4)}}}}); #if TVM_LLVM_VERSION >= 110 TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") .set_config({{"kind", String("cuda")}, {"arch", String("sm_72")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(8)}}}}); + {"num-cores", runtime::Int(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(6)}}}}); + {"num-cores", runtime::Int(6)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a78")}, - {"num-cores", Integer(8)}}}}); + {"num-cores", runtime::Int(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a78")}, - {"num-cores", Integer(12)}}}}); + {"num-cores", runtime::Int(12)}}}}); #endif // TVM_LLVM_VERSION >= 110 #endif // TVM_LLVM_HAS_AARCH64_TARGET @@ -139,10 +139,10 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") {"kind", String("cuda")}, \ {"keys", Array{"cuda", "gpu"}}, \ {"arch", String(Arch)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"max_threads_per_block", Integer(1024)}, \ - {"thread_warp_size", Integer(32)}, \ - {"registers_per_block", Integer(RegPerBlock)}, \ + {"max_shared_memory_per_block", runtime::Int(SharedMem)}, \ + {"max_threads_per_block", runtime::Int(1024)}, \ + {"thread_warp_size", runtime::Int(32)}, \ + {"registers_per_block", runtime::Int(RegPerBlock)}, \ }) // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus @@ -158,9 +158,9 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(41943040)); + .with_config("l2_cache_size_bytes", runtime::Int(41943040)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(52428800)); + .with_config("l2_cache_size_bytes", runtime::Int(52428800)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); @@ -263,7 +263,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(75497472)); + .with_config("l2_cache_size_bytes", runtime::Int(75497472)); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536); @@ -416,7 +416,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ {"keys", Array{"x86", "cpu"}}, \ {"mcpu", String(Arch)}, \ - {"num-cores", Integer(Cores)}}); + {"num-cores", runtime::Int(Cores)}}); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.xlarge", 2, "skylake-avx512"); @@ -432,9 +432,9 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ TVM_REGISTER_TARGET_TAG(Name).set_config( \ {{"kind", String("metal")}, \ - {"max_threads_per_block", Integer(ThreadsPerBlock)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"thread_warp_size", Integer(WarpSize)}, \ + {"max_threads_per_block", runtime::Int(ThreadsPerBlock)}, \ + {"max_shared_memory_per_block", runtime::Int(SharedMem)}, \ + {"thread_warp_size", runtime::Int(WarpSize)}, \ {"host", Map{{"kind", String("llvm")}, \ {"mtriple", String("arm64-apple-macos")}, \ {"mcpu", String("apple-latest")}}}}); diff --git a/src/target/target.cc b/src/target/target.cc index cd2e3714e422..a8337b58ae9b 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -359,24 +359,31 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { std::string interp_str = Interpret(str); - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing integer + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex() || + info.type_index == runtime::Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + // Parsing integer or boolean std::istringstream is(interp_str); int v; if (!(is >> v)) { std::string lower(interp_str.size(), '\x0'); std::transform(interp_str.begin(), interp_str.end(), lower.begin(), [](unsigned char c) { return std::tolower(c); }); - // Bool is a subclass of IntImm, so allow textual boolean values. + // Mimic C++ automatic conversions, allowing bool to be used for + // integer parameters. if (lower == "true") { v = 1; } else if (lower == "false") { v = 0; } else { - throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); + throw Error(": Cannot parse integer from string: " + interp_str); } } - return Integer(v); + + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return runtime::Int(v); + } else { + return runtime::Bool(v); + } } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing string, strip leading/trailing spaces, and enclosing quotes if any auto start = interp_str.find_first_not_of(' '); @@ -410,13 +417,13 @@ ObjectRef TargetInternal::ParseType(const std::string& str, ObjectRef TargetInternal::ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) { - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer - return GetRef(ObjTypeCheck(obj, "Integer")); - } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return GetRef(ObjTypeCheck(obj, "runtime.BoxInt")); + } else if (info.type_index == String::ContainerType::RuntimeTypeIndex()) { // Parsing string return GetRef(ObjTypeCheck(obj, "String")); - } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target if (auto opt = obj.as()) { return opt.value(); @@ -483,7 +490,11 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, /********** Stringifying **********/ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { - if (const auto* p = obj.as()) { + if (const auto* p = obj.as()) { + return std::to_string(p->value); + } else if (const auto* p = obj.as()) { + return std::to_string(p->value); + } else if (const auto* p = obj.as()) { return std::to_string(p->value); } if (auto tvm_str = obj.as()) { @@ -494,7 +505,7 @@ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { } return u; } - LOG(FATAL) << "Cannot stringify this object"; + LOG(FATAL) << "Cannot stringify object of type " << obj->GetTypeKey(); } std::string TargetInternal::StringifyArray(const ArrayNode& array) { @@ -953,7 +964,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // If requested, query attributes from the device. User-specified // parameters take precedence over queried parameters. if (attrs.count("from_device")) { - int device_id = Downcast(attrs.at("from_device")).IntValue(); + int device_id = Downcast(attrs.at("from_device"))->value; attrs.erase("from_device"); auto device_params = QueryDevice(device_id, target.get()); @@ -1006,38 +1017,13 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, for (const auto& kv : target->kind->key2vtype_) { const String& key = kv.first; - const TargetKindNode::ValueTypeInfo& type_info = kv.second; TVMRetValue ret; api->GetTargetProperty(device, key, &ret); - switch (ret.type_code()) { - case kTVMNullptr: - // Nothing returned for this parameter, move on to the next one. - continue; - - case kTVMArgInt: - if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Integer(static_cast(ret)); - } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Bool(static_cast(ret)); - } else { - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received integer from device api"; - } - break; - - case kTVMStr: - ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) - << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received string from device api"; - output[key] = String(ret.operator std::string()); - break; - - default: - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; - break; + // Delegate conversion from TVMRetValue to the FFI's default conversions. + if (Optional opt = ret) { + output[key] = opt.value(); } } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 708d3ccd7621..fced74c3a559 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -243,7 +243,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { * \return The updated attributes */ TargetJSON TestTargetParser(TargetJSON target) { - Map features = {{"is_test", Bool(true)}}; + Map features = {{"is_test", runtime::Bool(true)}}; target.Set("features", features); return target; } @@ -256,16 +256,16 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mtriple") .add_attr_option("mfloat-abi") .add_attr_option("mabi") - .add_attr_option("num-cores") + .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags - .add_attr_option("fast-math") // implies all the below - .add_attr_option("fast-math-nnan") - .add_attr_option("fast-math-ninf") - .add_attr_option("fast-math-nsz") - .add_attr_option("fast-math-arcp") - .add_attr_option("fast-math-contract") - .add_attr_option("fast-math-reassoc") - .add_attr_option("opt-level") + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") // LLVM command line flags, see below .add_attr_option>("cl-opt") // LLVM JIT engine mcjit/orcjit @@ -273,7 +273,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .set_default_keys({"cpu"}) // Force the external codegen kind attribute to be registered, even if no external // codegen targets are enabled by the TVM build. - .set_attr(tvm::attr::kIsExternalCodegen, Bool(false)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(false)) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); // Note regarding the "cl-opt" attribute: @@ -301,28 +301,29 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("mcpu") .add_attr_option("march") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constants-byte-alignment") + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constants-byte-alignment") .set_default_keys({"cpu"}) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("arch") - .add_attr_option("max_shared_memory_per_block") - .add_attr_option("max_threads_per_block") - .add_attr_option("thread_warp_size", Integer(32)) - .add_attr_option("registers_per_block") - .add_attr_option("l2_cache_size_bytes") - .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_threads_per_block") + .add_attr_option("thread_warp_size", runtime::Int(32)) + .add_attr_option("registers_per_block") + .add_attr_option("l2_cache_size_bytes") + .add_attr_option("max_num_threads", + runtime::Int(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("mtriple") - .add_attr_option("max_num_threads", Integer(1024)) - .add_attr_option("thread_warp_size", Integer(32)) + .add_attr_option("max_num_threads", runtime::Int(1024)) + .add_attr_option("thread_warp_size", runtime::Int(32)) .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateNVPTXAttrs); @@ -332,24 +333,24 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option>("mattr") // TODO(masahi): Support querying from a target device // On RDNA cards, thread_warp_size should be 32 - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(65536)) - .add_attr_option("thread_warp_size", Integer(64)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(65536)) + .add_attr_option("thread_warp_size", runtime::Int(64)) .set_default_keys({"rocm", "gpu"}) .set_target_parser(UpdateROCmAttrs); TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(16384)) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("texture_spatial_limit", Integer(16384)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(16384)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("thread_warp_size", runtime::Int(1)) + .add_attr_option("texture_spatial_limit", runtime::Int(16384)) // Faced that Qualcomm OpenCL runtime crashed without any error message in // the case when the number of kernel arguments was pretty big. OpenCL doesn't // specify any limitations on the number of kernel arguments. max_function_args // equals to 128 looks like a reasonable number of kernel arguments. - .add_attr_option("max_function_args", Integer(128)) + .add_attr_option("max_function_args", runtime::Int(128)) .set_default_keys({"opencl", "gpu"}); // The metal has some limitations on the number of input parameters. This is why attribute @@ -358,55 +359,55 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) // https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc // See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf TVM_REGISTER_TARGET_KIND("metal", kDLMetal) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(32768)) - .add_attr_option("thread_warp_size", Integer(16)) - .add_attr_option("max_function_args", Integer(31)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(32768)) + .add_attr_option("thread_warp_size", runtime::Int(16)) + .add_attr_option("max_function_args", runtime::Int(31)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option>("mattr") // Feature support - .add_attr_option("supports_float16") - .add_attr_option("supports_float32", Bool(true)) - .add_attr_option("supports_float64") - .add_attr_option("supports_int8") - .add_attr_option("supports_int16") - .add_attr_option("supports_int32", Bool(true)) - .add_attr_option("supports_int64") - .add_attr_option("supports_8bit_buffer") - .add_attr_option("supports_16bit_buffer") - .add_attr_option("supports_storage_buffer_storage_class") - .add_attr_option("supports_push_descriptor") - .add_attr_option("supports_dedicated_allocation") - .add_attr_option("supports_integer_dot_product") - .add_attr_option("supports_cooperative_matrix") - .add_attr_option("supported_subgroup_operations") + .add_attr_option("supports_float16") + .add_attr_option("supports_float32", runtime::Bool(true)) + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32", runtime::Bool(true)) + .add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix") + .add_attr_option("supported_subgroup_operations") // Physical device limits - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("max_block_size_x") - .add_attr_option("max_block_size_y") - .add_attr_option("max_block_size_z") - .add_attr_option("max_push_constants_size") - .add_attr_option("max_uniform_buffer_range") - .add_attr_option("max_storage_buffer_range") - .add_attr_option("max_per_stage_descriptor_storage_buffer") - .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("thread_warp_size", runtime::Int(1)) + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") // Other device properties .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_name") - .add_attr_option("driver_version") - .add_attr_option("vulkan_api_version") - .add_attr_option("max_spirv_version") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") // Tags .set_default_keys({"vulkan", "gpu"}); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) - .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_num_threads", runtime::Int(256)) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL) // line break @@ -423,8 +424,8 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) .add_attr_option("mcpu") .add_attr_option("mtriple") .add_attr_option>("llvm-options") - .add_attr_option("num-cores") - .add_attr_option("vtcm-capacity") + .add_attr_option("num-cores") + .add_attr_option("vtcm-capacity") .set_default_keys({"hexagon", "cpu"}); TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 5797d2295bab..fb839c28da96 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -56,10 +56,25 @@ TVM_REGISTER_NODE_TYPE(ComputeOpNode); /// Verify if ComputeOp is valid with respect to Reduce operations. static void VerifyComputeOp(const ComputeOpNode* op); -inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && StructuralEqual()(a->condition, b->condition) && - ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); +static inline void AssertReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { + const char* shared_text = + "When a TE compute node produces multiple outputs, " + "each of which is a reduction, " + "each reduction must be structurally identical, " + "except for the ReduceNode::value_index. "; + + StructuralEqual eq; + + ICHECK(a->combiner.same_as(b->combiner)) << shared_text << "However, the reduction operation " + << a->combiner << " does not match " << b->combiner; + ICHECK(a->source.same_as(b->source)) + << shared_text << "However, the input " << a->source << " does not match " << b->source; + ICHECK(eq(a->axis, b->axis)) << shared_text << "However, the reduction axis " << a->axis + << " does not match " << b->axis; + ICHECK(eq(a->condition, b->condition)) << shared_text << "However, the predicate " << a->condition + << " does not match " << b->condition; + ICHECK(eq(a->init, b->init)) << shared_text << "However, the initial value " << a->init + << " does not match " << b->init; } int ComputeOpNode::num_outputs() const { return body.size(); } @@ -529,8 +544,7 @@ class ComputeVerifier final : protected tir::ExprVisitor { << "with being Reduce operation or not."; if (reduce && reduce_) { - ICHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; + AssertReduceEqual(reduce, reduce_); } level_ = 0; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 2eb0693685a6..b5a87d9446d8 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -355,11 +355,12 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Array seq_stmt; if (compute_op->body[0]->IsInstance()) { auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { - return a->combiner.same_as(b->combiner) && // - a->source.same_as(b->source) && // - a->axis.same_as(b->axis) && // - a->condition.same_as(b->condition) && // - ((a->init.empty() && b->init.empty()) || a->init.same_as(b->init)); + StructuralEqual eq; + return eq(a->combiner, b->combiner) && // + eq(a->source, b->source) && // + eq(a->axis, b->axis) && // + eq(a->condition, b->condition) && // + eq(a->init, b->init); }; PrimExpr expr_body = compute_op->body[0]; @@ -370,7 +371,9 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in const tir::ReduceNode* reduce_ = compute_op->body[k].as(); ICHECK(reduce_); ICHECK(f_reducer_equal(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should have the same attribute except value_index"; + << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " + << "but the first argument has body " << GetRef(reduce_) << ", while the " << k + << "-th argument has body " << GetRef(reduce); tensors.push_back(compute_op.output(k)); } diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 4f5df7ad3024..774a0f8f1f89 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -63,7 +63,17 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { } TVM_REGISTER_GLOBAL("te.Placeholder") - .set_body_typed([](Array shape, DataType dtype, std::string name) { + .set_body_typed([](Variant> shape_arg, DataType dtype, + std::string name) { + auto shape = [&]() -> Array { + if (auto arg_expr = shape_arg.as()) { + return {arg_expr.value()}; + } else if (auto arg_array = shape_arg.as>()) { + return arg_array.value(); + } else { + LOG(FATAL) << "Variant did not contain either allowed type"; + } + }(); return placeholder(shape, dtype, name); }); diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index c38c5a5c800b..1ad8914e48cc 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -124,9 +124,10 @@ void ReplaceDataFlow(const Array& stages, std::unordered_mapcombiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) && - ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); + StructuralEqual struct_equal; + return struct_equal(a->combiner, b->combiner) && struct_equal(a->source, b->source) && + struct_equal(a->axis, b->axis) && struct_equal(a->condition, b->condition) && + struct_equal(a->init, b->init); } Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 3a41c5ac5a25..70e82a605369 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -134,7 +134,7 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) { int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true); if (target.defined() && target->kind->name == "hexagon") { - auto value = Downcast(target->attrs.at("vtcm-capacity"))->value; + auto value = target->GetAttr("vtcm-capacity").value()->value; if (value > 0) return value; } return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 1506082003fd..c38237a664f7 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -35,6 +35,18 @@ namespace tvm { namespace tir { +/* \brief Convert an object to a PrimExpr + * + * All conversions to a PrimExpr are performed as part of the FFI, + * when calling a function that accepts a PrimExpr as an argument. If + * a function must normalize to a PrimExpr (e.g. before accessing the + * `expr.dtype` field), this function allows the FFI conversions to be + * explicitly invoked. + */ +TVM_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { + return expr; +}); + #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ Name::Name(PrimExpr a, PrimExpr b, Span span) { \ using T = Name::ContainerType; \ @@ -546,7 +558,9 @@ Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { + .set_body_typed([](DataType type, RelayExpr op, + Array> args, + Span span) { Array prim_expr_args; for (const auto& it : args) { ICHECK(it->IsInstance() || it->IsInstance() || @@ -707,9 +721,11 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis if (!init.empty()) { ICHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs"; for (size_t i = 0; i < init.size(); i++) { + ICHECK(init[i].defined()) << "Init value must be defined"; ICHECK(init[i]->IsInstance() || init[i]->IsInstance() || init[i]->IsInstance()) - << "init can only be a IntImm, FloatImm or ProducerLoad"; + << "init can only be a IntImm, FloatImm or ProducerLoad, " + << "but received " << init[i] << " of type " << init[i]->GetTypeKey(); } } n->dtype = source[value_index].dtype(); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 14dd0eadb65c..2c94b9d8646b 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -27,6 +27,8 @@ #include #include +#include "utils.h" + namespace tvm { namespace tir { namespace { @@ -79,6 +81,11 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, if (!ret_type.defined()) { ret_type = VoidType(); } + + if (attrs.defined()) { + attrs = Downcast(NormalizeAttributeObject(attrs)); + } + auto n = make_object(); n->params = std::move(params); n->body = std::move(body); diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index b30d0caf6af3..78fb9365cc71 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -414,7 +414,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx /**************** Implementation ****************/ -PrimFunc Specialize(PrimFunc func, const Map& param_map) { +PrimFunc Specialize(PrimFunc func, const Map>& param_map) { VarMap var_map; for (const auto& kv : param_map) { const Var& param = kv.first; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 5df76450ff1e..9c8f580b5413 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -27,6 +27,7 @@ #include #include "buffer_common.h" +#include "utils.h" namespace tvm { namespace tir { @@ -61,6 +62,15 @@ TVM_REGISTER_NODE_TYPE(LetStmtNode); // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { + // The nodes are not required to be a TIR type, and may legally + // contain any ObjectRef. However, normalizing to an IR type if + // possible prevents spurious discrepancies in StructuralEqual(). + if (auto opt = node.as()) { + node = Bool(opt.value()); + } else if (auto opt = node.as()) { + node = Integer(opt.value()); + } + auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); @@ -109,13 +119,21 @@ TVM_REGISTER_GLOBAL("tir.AssertStmt") // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional thread_binding, Map annotations, Span span) { + ICHECK(loop_var.defined()); ICHECK(min.defined()); ICHECK(extent.defined()); - ICHECK(min.dtype().is_scalar()); - ICHECK(extent.dtype().is_scalar()); - ICHECK(loop_var.dtype().is_scalar()); ICHECK(body.defined()); + auto require_scalar_int_dtype = [&](PrimExpr expr, const char* field_name) { + auto dtype = expr.dtype(); + CHECK(dtype.is_scalar() && (dtype.is_int() || dtype.is_uint())) + << "TIR For nodes require a scalar integer as the " << field_name << ", but received " + << expr << " with dtype " << dtype; + }; + require_scalar_int_dtype(loop_var, "loop_var"); + require_scalar_int_dtype(min, "min"); + require_scalar_int_dtype(extent, "extent"); + // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them // without raising errors. auto try_promote_imm_dtype = [&](const PrimExpr& e) { @@ -136,6 +154,8 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); @@ -234,6 +254,8 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim ICHECK(condition.defined()); ICHECK(condition.dtype().is_bool()); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; @@ -288,6 +310,8 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array ext ICHECK(body.defined()); ICHECK(data_or_idx.defined()); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; @@ -652,6 +676,8 @@ Block::Block(Array iter_vars, Array reads, Array init, Array alloc_buffers, Array match_buffers, Map annotations, Span span) { + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->iter_vars = std::move(iter_vars); node->reads = std::move(reads); diff --git a/src/tir/ir/utils.cc b/src/tir/ir/utils.cc new file mode 100644 index 000000000000..0e3dc1237894 --- /dev/null +++ b/src/tir/ir/utils.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tir/ir/utils.cc + * \brief Utilities for manipulating TIR + */ +#include "utils.h" + +#include + +namespace tvm { +namespace tir { + +ObjectRef NormalizeAttributeObject(ObjectRef obj) { + if (const auto* runtime_int = obj.as()) { + return Integer(runtime_int->value); + } else if (const auto* runtime_bool = obj.as()) { + return Bool(runtime_bool->value); + } else if (const auto* runtime_float = obj.as()) { + return FloatImm(DataType::Float(32), runtime_float->value); + } else if (auto opt_array = obj.as>()) { + return opt_array.value().Map(NormalizeAttributeObject); + } else if (auto opt_map = obj.as>()) { + Map new_map; + bool is_same = true; + + for (const auto& [key, obj] : opt_map.value()) { + ObjectRef new_obj = NormalizeAttributeObject(obj); + is_same = is_same && obj.same_as(new_obj); + new_map.Set(key, new_obj); + } + + if (is_same) { + return obj; + } else { + return new_map; + } + } else if (auto dict_attrs = obj.as()) { + auto new_attrs = Downcast>(NormalizeAttributeObject(dict_attrs->dict)); + if (new_attrs.same_as(dict_attrs->dict)) { + return GetRef(dict_attrs); + } else { + return DictAttrs(new_attrs); + } + } else { + return obj; + } +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/utils.h b/src/tir/ir/utils.h new file mode 100644 index 000000000000..b1f7a722899f --- /dev/null +++ b/src/tir/ir/utils.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/ir/utils.h + * \brief Utilities for manipulating TIR + */ +#ifndef TVM_TIR_IR_UTILS_H_ +#define TVM_TIR_IR_UTILS_H_ + +#include + +namespace tvm { +namespace tir { + +/* \brief Normalize an ObjectRef held + * + * Where possible, the IR should be normalized contain IR types. For + * example, holding a `tir::IntImm` instead of a `runtime::Int`. In + * attributes, this is not always possible, as attributes may refer to + * non-IR objects. + * + * This function normalizes any `runtime::Int`, `runtime::Bool`, + * `runtime::Float`, or containers of those types to the corresponding + * IR type. + * + * \param obj The attribute object to be normalized + * + * \returns The normalized attribute + */ +ObjectRef NormalizeAttributeObject(ObjectRef obj); + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_IR_UTILS_H_ diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index c79a148e4b6e..dad4ea98d614 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -229,9 +229,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } PrimExpr ret(PrimExpr value, Span span) { + CHECK(value.defined()); return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } +TVM_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); + // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { using namespace tir; @@ -1048,12 +1051,15 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[0].type_code() == kDLInt) { - *ret = tir::make_const(args[1], args[0].operator int64_t(), args[2]); - } else if (args[0].type_code() == kDLFloat) { - *ret = tir::make_const(args[1], args[0].operator double(), args[2]); + if (auto opt = args[0].TryAsInt()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); + } else if (auto opt = args[0].TryAsBool()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); + } else if (auto opt = args[0].TryAsFloat()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); } else { - LOG(FATAL) << "only accept int or float"; // FIXME + LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or bool, " + << "but instead received argument with type code " << args[0].type_code(); // FIXME } }); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index cda501cd992e..73b5ff3fafd4 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -233,9 +233,9 @@ support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } -ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); @@ -914,6 +914,14 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_ if (ann_val.as()) { return ann_val; } + if (auto* runtime_int = ann_val.as()) { + return IntImm(DataType::Int(32), runtime_int->value); + } else if (auto* runtime_float = ann_val.as()) { + return FloatImm(DataType::Float(32), runtime_float->value); + } else if (auto* runtime_bool = ann_val.as()) { + return Bool(runtime_bool->value); + } + if (const auto* expr = ann_val.as()) { ICHECK(!ann_val->IsInstance()) << "TypeError: runtime::String is expected, but gets StringImm"; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 4eccff10a2c7..092bcf0c79f9 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -87,8 +87,9 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) override; + ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 122c5ff0d9fe..9209e6578687 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -439,6 +439,11 @@ inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os } else if (const auto* float_imm = obj.as()) { os.precision(17); os << float_imm->value; + } else if (const auto* runtime_int = obj.as()) { + os << runtime_int->value; + } else if (const auto* runtime_float = obj.as()) { + os.precision(17); + os << runtime_float->value; } else if (const auto* array = obj.as()) { os << '['; bool is_first = true; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index fe1c1850dcd5..fd1349e4a3ec 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -55,8 +55,9 @@ std::vector SampleWithoutReplacement( * \return The random variable sampled from candidates */ TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision); + const Array& candidates, + const Array& probs, + Optional* decision); /*! * \brief Create a sampling function that does multinomial sampling. * \param rand_state The random state. diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index 92c3423bcbbb..4c7b208e964f 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include "../../ir/utils.h" #include "../utils.h" namespace tvm { @@ -97,6 +98,8 @@ struct AnnotateTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val, String ann_key) { + ann_val = NormalizeAttributeObject(ann_val); + if (auto block = block_or_loop_rv.as()) { return sch->Annotate(block.value(), ann_key, ann_val); } diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 2a2f17355ca6..8e16f50b8b95 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -163,19 +163,18 @@ std::vector SampleWithoutReplacement( } int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision) { + const Array& candidates, const Array& probs, + Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; int32_t i = -1; int32_t n = candidates.size(); if (decision->defined()) { - const auto* int_imm = decision->as(); - i = int_imm->value; + i = decision->value()->value; CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - std::vector weights = support::AsVector(probs); + std::vector weights = support::AsVector(probs); std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); @@ -183,8 +182,8 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st << ", but decision is: " << i; } - *decision = Integer(i); // decision is guaranteed not to be nullptr. - return candidates[i].IntValue(); + *decision = runtime::Int(i); // decision is guaranteed not to be nullptr. + return candidates[i]->value; } std::function MakeMultinomialSampler( @@ -461,24 +460,11 @@ struct SampleCategoricalTraits : public UnpackedInstTraits candidates, // - Array probs, // - Optional decision) { - Array probs_float = probs.Map([](const ObjectRef& prob) { - const auto* prob_float = prob.as(); - if (prob_float != nullptr) { - return GetRef(prob_float); - } - const auto* prob_int = prob.as(); - if (prob_int != nullptr) { - return FloatImm(DataType::Float(32), static_cast(prob_int->value)); - } - LOG(FATAL) - << "SampleCategorical does not accept probability with type other than float or int."; - throw; - }); - return sch->SampleCategorical(candidates, probs_float, decision); + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + Array candidates, // + Array probs, // + Optional decision) { + return sch->SampleCategorical(candidates, probs, decision); } static String UnpackedAsPython(Array outputs, // diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 4b10df7e9728..6e243bf19198 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -112,7 +112,9 @@ Array TranslateInputRVs( } else if (const auto* str_obj = input.as()) { // Case 2. string => "content" results.push_back(String('"' + std::string(str_obj->data) + '"')); - } else if (input->IsInstance() || input->IsInstance()) { + } else if (input->IsInstance() || input->IsInstance() || + input->IsInstance() || + input->IsInstance()) { // Case 3. integer or floating-point number results.push_back(input); } else if (input->IsInstance()) { @@ -149,7 +151,9 @@ Array TranslateInputRVs(const Array& inputs, results.reserve(inputs.size()); for (const ObjectRef& input : inputs) { // Case 3. integer or floating-point number - if (input->IsInstance() || input->IsInstance()) { + if (input->IsInstance() || input->IsInstance() || + input->IsInstance() || + input->IsInstance()) { results.push_back(input); continue; } @@ -388,9 +392,9 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { try { const ArrayNode* arr = decision_entry.as(); ICHECK(arr && arr->size() == 2); - const IntImmNode* arr0 = arr->at(0).as(); + auto arr0 = arr->at(0).as(); ICHECK(arr0); - index = arr0->value; + index = arr0.value(); decision = arr->at(1); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, " diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 16c4350aaee6..1611109d7735 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -53,9 +53,9 @@ Schedule TracedScheduleNode::Copy() { /******** Schedule: Sampling ********/ -ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { ExprRV result = CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 686d84ebc6fe..78629e84f039 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,8 +47,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) final; + ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) final; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index cc33ba9f86c2..14672f568549 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -231,7 +231,7 @@ class PrimFuncInliner : StmtExprMutator { << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; - Map param_map; + Map> param_map; for (size_t i = 0; i < callee->params.size(); i++) { param_map.Set(callee->params[i], args[i]); } diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 423b0ca92237..2948773321dd 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -155,6 +155,7 @@ inline DataType APIType(DataType t) { ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; if (t.is_handle()) return t; ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; + if (t.is_bool()) return DataType::Bool(); if (t.is_uint() || t.is_int()) return DataType::Int(64); ICHECK(t.is_float()); return DataType::Float(64); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 1a3888a7cd48..1cde4f2ebe7d 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -511,6 +511,8 @@ class BuiltinLower : public StmtExprMutator { arg_tcode = kTVMStr; } else if (IsArrayHandle(arg)) { arg_tcode = kTVMDLTensorHandle; + } else if (arg.dtype().is_bool()) { + arg_tcode = kTVMArgBool; } // opaque handle need to set the kind properly if (arg_tcode == kTVMOpaqueHandle) { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d327cdfa8393..9f2f1295fece 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -263,15 +263,15 @@ PrimFunc MakePackedAPI(PrimFunc func) { // --------------------------- // local function definitions // load i-th argument as type t - auto f_arg_value = [&](DataType t, int i) { + auto f_arg_value = [&](DataType arg_type, int i) { Array call_args{v_packed_args, IntImm(DataType::Int(32), i), IntImm(DataType::Int(32), builtin::kTVMValueContent)}; // load 64 bit version - DataType api_type = APIType(t); + DataType api_type = APIType(arg_type); PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); // cast to the target version. - if (api_type != t) { - res = Cast(t, res); + if (api_type != arg_type) { + res = Cast(arg_type, res); } return res; }; @@ -319,10 +319,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { continue; } - var_def.emplace_back(f_arg_value(param.dtype(), i), param); - if (func_ptr->buffer_map.count(param)) { - buffer_def.emplace_back(param, func_ptr->buffer_map[param]); - } + PrimExpr arg_value; // type code checks Var tcode(param->name_hint + ".code", DataType::Int(32)); @@ -335,15 +332,45 @@ PrimFunc MakePackedAPI(PrimFunc func) { seq_init.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = f_arg_value(param.dtype(), i); + } else if (t.is_bool()) { + std::ostringstream msg; + msg << name_hint << ": Expect arg[" << i << "] to be boolean"; + seq_init.emplace_back( + AssertStmt(tcode == kTVMArgBool || tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = Call(t, builtin::if_then_else(), + { + tcode == kTVMArgBool, + f_arg_value(DataType::Bool(), i), + cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)), + }); + } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_init.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); + seq_init.emplace_back( + AssertStmt(tcode == kDLInt || tcode == kTVMArgBool, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = Call(t, builtin::if_then_else(), + { + tcode == kTVMArgInt, + f_arg_value(t, i), + cast(t, f_arg_value(DataType::Bool(), i)), + }); } else { ICHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; seq_init.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = f_arg_value(param.dtype(), i); + } + + var_def.emplace_back(arg_value, param); + if (func_ptr->buffer_map.count(param)) { + buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } } diff --git a/tests/cpp/relay/backend/runtime_test.cc b/tests/cpp/relay/backend/runtime_test.cc index 53ea7e39ed59..adabb9b9b6cf 100644 --- a/tests/cpp/relay/backend/runtime_test.cc +++ b/tests/cpp/relay/backend/runtime_test.cc @@ -26,13 +26,13 @@ namespace tvm { namespace relay { TVM_REGISTER_RUNTIME("TestRuntime") - .add_attr_option("my_bool") + .add_attr_option("my_bool") .add_attr_option>("your_names") .add_attr_option("another_option") - .add_attr_option("defaulty_the_default_option", Bool(false)); + .add_attr_option("defaulty_the_default_option", runtime::Bool(false)); TEST(Runtime, Create) { - Map attrs = {{"my_bool", Bool(true)}}; + Map attrs = {{"my_bool", runtime::Bool(true)}}; Runtime my_runtime = Runtime::Create("TestRuntime", attrs); ASSERT_EQ(my_runtime->GetAttr("my_bool"), true); ASSERT_EQ(my_runtime->GetAttr>("your_names").defined(), false); @@ -40,7 +40,7 @@ TEST(Runtime, Create) { } TEST(Runtime, UnknownAttr) { - Map attrs = {{"woofles", Bool(true)}}; + Map attrs = {{"woofles", runtime::Bool(true)}}; ASSERT_THROW(Runtime::Create("TestRuntime", attrs), Error); } @@ -64,7 +64,7 @@ TEST(RuntimeRegistry, ListRuntimeOptions) { Map attrs = Runtime::ListRuntimeOptions("TestRuntime"); ICHECK_EQ(attrs.empty(), false); - ICHECK_EQ(attrs["my_bool"], "IntImm"); + ICHECK_EQ(attrs["my_bool"], "runtime.BoxBool"); ICHECK_EQ(attrs["your_names"], "Array"); ICHECK_EQ(attrs["another_option"], "runtime.String"); } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 2db4b572bf60..0a2b8206d322 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -32,15 +32,15 @@ using namespace tvm; TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU) .set_attr("Attr1", "Value1") - .add_attr_option("my_bool") + .add_attr_option("my_bool") .add_attr_option>("your_names") - .add_attr_option>("her_maps"); + .add_attr_option>("her_maps"); TargetJSON TestTargetParser(TargetJSON target) { String mcpu = Downcast(target.at("mcpu")); target.Set("mcpu", String("super_") + mcpu); target.Set("keys", Array({"super"})); - target.Set("features", Map{{"test", Bool(true)}}); + target.Set("features", Map{{"test", runtime::Bool(true)}}); return target; } @@ -76,14 +76,14 @@ TEST(TargetKind, GetAttrMap) { TEST(TargetCreation, NestedConfig) { Map config = { - {"my_bool", Bool(true)}, + {"my_bool", runtime::Bool(true)}, {"your_names", Array{"junru", "jian"}}, {"kind", String("TestTargetKind")}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -91,13 +91,14 @@ TEST(TargetCreation, NestedConfig) { ICHECK_EQ(target->kind, TargetKind::Get("TestTargetKind").value()); ICHECK_EQ(target->tag, ""); ICHECK(target->keys.empty()); - Bool my_bool = target->GetAttr("my_bool").value(); + runtime::Bool my_bool = target->GetAttr("my_bool").value(); ICHECK_EQ(my_bool.operator bool(), true); Array your_names = target->GetAttr>("your_names").value(); ICHECK_EQ(your_names.size(), 2U); ICHECK_EQ(your_names[0], "junru"); ICHECK_EQ(your_names[1], "jian"); - Map her_maps = target->GetAttr>("her_maps").value(); + Map her_maps = + target->GetAttr>("her_maps").value(); ICHECK_EQ(her_maps.size(), 2U); ICHECK_EQ(her_maps["a"], 1); ICHECK_EQ(her_maps["b"], 2); @@ -105,15 +106,15 @@ TEST(TargetCreation, NestedConfig) { TEST(TargetCreationFail, UnrecognizedConfigOption) { Map config = { - {"my_bool", Bool(true)}, + {"my_bool", runtime::Bool(true)}, {"your_names", Array{"junru", "jian"}}, {"kind", String("TestTargetKind")}, {"bad", ObjectRef(nullptr)}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -133,9 +134,9 @@ TEST(TargetCreationFail, TypeMismatch) { {"kind", String("TestTargetKind")}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -150,13 +151,13 @@ TEST(TargetCreationFail, TypeMismatch) { TEST(TargetCreationFail, TargetKindNotFound) { Map config = { - {"my_bool", Bool("true")}, + {"my_bool", runtime::Bool("true")}, {"your_names", Array{"junru", "jian"}}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -178,15 +179,16 @@ TEST(TargetCreation, TargetParser) { TEST(TargetCreation, TargetFeatures) { Target test_target_with_parser("TestTargetParser -mcpu=woof"); - ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); + ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); Target test_target_no_parser("TestTargetKind"); - ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); - ASSERT_EQ(test_target_no_parser->GetFeature("test", Bool(true)).value(), true); + ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); + ASSERT_EQ(test_target_no_parser->GetFeature("test", runtime::Bool(true)).value(), + true); } TEST(TargetCreation, TargetFeaturesBeforeParser) { - Map features = {{"test", Bool(true)}}; + Map features = {{"test", runtime::Bool(true)}}; Map config = { {"kind", String("TestTargetParser")}, {"mcpu", String("woof")}, @@ -469,13 +471,13 @@ TEST(TargetCreation, DetectSystemTriple) { #endif TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_1", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU) .set_attr(tvm::attr::kRelayToTIR, diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py index bbfb8bd2db12..f5b1651e115a 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py +++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py @@ -15,10 +15,14 @@ # specific language governing permissions and limitations # under the License. """Test packed function FFI.""" +import gc + +import numpy as np + import tvm from tvm import te import tvm.testing -import numpy as np +from tvm.script import tir as T def test_get_global(): @@ -37,7 +41,7 @@ def my_packed_func(*args): def test_get_callback_with_node(): - x = tvm.runtime.convert(10) + x = T.int32(10) def test(y): assert y.handle != x.handle @@ -66,7 +70,7 @@ def add(x): myf = tvm.runtime.convert(addy) f = myf(10) - assert f(11).value == 21 + assert f(11) == 21 def test_convert(): @@ -113,6 +117,14 @@ def test_device_func(dev): def test_rvalue_ref(): def callback(x, expected_count): + # The use count of TVM objects is decremented as part of + # `ObjectRef.__del__`, which runs when the Python object is + # destructed. However, Python object destruction is not + # deterministic, and even CPython's reference-counting is + # considered an implementation detail. Therefore, to ensure + # correct results from this test, `gc.collect()` must be + # explicitly called. + gc.collect() assert expected_count == tvm.testing.object_use_count(x) return x diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index afd716cde389..42f5b0ccd0b8 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -16,16 +16,27 @@ # under the License. import tvm import tvm.testing -from tvm import te +from tvm import te, tir +from tvm.script import tir as T class CanonicalChecker: def __init__(self): self.analyzer = tvm.arith.Analyzer() + def _convert(self, expr): + # TODO(Lunderberg): Make utility functions `tir.convert` and + # `relax.convert` that convert to their respective IR types. + # Implementation should be in C++, and should only consist of + # conversions that are applied automatically through FFI. + if isinstance(expr, int): + return T.int32(expr) + else: + return expr + def verify(self, data, expected): res = self.analyzer.canonical_simplify(data) - expected = tvm.runtime.convert(expected) + expected = self._convert(expected) assert tvm.ir.structural_equal(res, expected), "\ndata={}\nres={}\nexpected={}".format( data, res, expected ) @@ -377,13 +388,13 @@ def test_simplify_normalize_min_value_expr(): x = te.var("x", "int32") ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32")) - ck.verify(te.min_value("int32") + x == 0, False) + ck.verify(te.min_value("int32") + x == 0, tir.const(False)) ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32")) - ck.verify(0 == te.min_value("int32") + x, False) + ck.verify(0 == te.min_value("int32") + x, tir.const(False)) ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32")) - ck.verify(x + te.min_value("int32") == 0, False) + ck.verify(x + te.min_value("int32") == 0, tir.const(False)) ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32")) - ck.verify(0 == x + te.min_value("int32"), False) + ck.verify(0 == x + te.min_value("int32"), tir.const(False)) def test_proddiv_simplify(): diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py index 3a10ec05efeb..f0e6f05adfad 100644 --- a/tests/python/arith/test_arith_iter_affine_map.py +++ b/tests/python/arith/test_arith_iter_affine_map.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm.tir import floordiv, floormod +from tvm.script import tir as T def ifuse(inputs, pred_extent=None): @@ -537,7 +538,7 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) tvm.ir.assert_structural_equal(res[0][1], x + c) tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18) - tvm.ir.assert_structural_equal(res[1][1], True) + tvm.ir.assert_structural_equal(res[1][1], T.bool(True)) # compound 1 i0 = create_iter("i0", 4) @@ -553,7 +554,7 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) @@ -569,7 +570,7 @@ def test_subspace_division(): assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices @@ -587,11 +588,11 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) - tvm.ir.assert_structural_equal(res[2][1], True) + tvm.ir.assert_structural_equal(res[2][1], T.bool(True)) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 @@ -606,9 +607,9 @@ def test_subspace_division(): assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) - tvm.ir.assert_structural_equal(res[2][0], True) + tvm.ir.assert_structural_equal(res[2][0], T.bool(True)) tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices @@ -642,10 +643,10 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices @@ -661,9 +662,9 @@ def test_subspace_division(): assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], j0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map( @@ -690,10 +691,10 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) @@ -735,8 +736,8 @@ def test_subspace_divide_trivial_iters(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], x) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], y) diff --git a/tests/python/arith/test_arith_narrow_predicate_expression.py b/tests/python/arith/test_arith_narrow_predicate_expression.py index d38fe70f6b5c..0aa353c60041 100644 --- a/tests/python/arith/test_arith_narrow_predicate_expression.py +++ b/tests/python/arith/test_arith_narrow_predicate_expression.py @@ -20,6 +20,7 @@ from tvm import tir from tvm.runtime import convert +from tvm.script import tir as T i = tir.Var("i", "int32") @@ -42,18 +43,18 @@ [i < n, i < 0], [i <= n, i <= 0], [i >= n, i >= 7], - [n > i, convert(0) > i], - [n < i, convert(7) < i], - [n <= i, convert(7) <= i], - [n >= i, convert(0) >= i], - [i == n, tir.all(i <= 0, convert(7) <= i)], - [n == i, tir.all(convert(7) <= i, i <= 0)], - [i != n, tir.any(i < 0, convert(7) < i)], - [n != i, tir.any(convert(7) < i, i < 0)], + [n > i, T.int32(0) > i], + [n < i, T.int32(7) < i], + [n <= i, T.int32(7) <= i], + [n >= i, T.int32(0) >= i], + [i == n, tir.all(i <= 0, T.int32(7) <= i)], + [n == i, tir.all(T.int32(7) <= i, i <= 0)], + [i != n, tir.any(i < 0, T.int32(7) < i)], + [n != i, tir.any(T.int32(7) < i, i < 0)], [i // 4 > n, i // 4 > 7], - [n < i // 4, convert(7) < i // 4], + [n < i // 4, T.int32(7) < i // 4], [(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0], - [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)], + [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, T.int32(0) <= tir.Add(i, 0) // 4)], [i + n < 10, i + 7 < 10], [i - n < 10, tir.Sub(i, 0) < 10], [tir.Not(i < n), tir.Not(i < 7)], diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 90f0aeef47d7..7fc1862192d6 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -27,6 +27,8 @@ from tvm.tir import truncdiv as tdiv from tvm.tir import truncmod as tmod +from tvm.script import tir as T + class TestCase: def __init__(self, before, expected, preconditions=None): @@ -35,10 +37,21 @@ def __init__(self, before, expected, preconditions=None): if isinstance(expected, tir.expr.EqualOp): expected = expected.asobject() - self.before = before - self.expected = expected + self.before = self._convert(before) + self.expected = self._convert(expected) self.preconditions = preconditions + @staticmethod + def _convert(expr): + if isinstance(expr, tir.expr.EqualOp): + return expr.asobject() + elif isinstance(expr, int): + return T.int32(expr) + elif isinstance(expr, float): + return T.float32(expr) + else: + return expr + @property def constraint(self): if self.preconditions is None: @@ -1008,8 +1021,8 @@ class TestComparisons(BaseCompare): TestCase(tir.all(fld(x, 8) == -3, flm(x, 8) == 4), x == -20), TestCase(tir.all(flm(x, 8) == 4, fld(x, 8) == -3), x == -20), # Rewrite based on definition of integer division - TestCase(tir.all(tvm.runtime.convert(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), - TestCase(tir.all(x - y * 5 < 5, tvm.runtime.convert(0) <= x - y * 5), y == fld(x, 5)), + TestCase(tir.all(T.int32(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), + TestCase(tir.all(x - y * 5 < 5, T.int32(0) <= x - y * 5), y == fld(x, 5)), # Narrow upper bound using floormod TestCase(tir.all(x < 20, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), TestCase(tir.all(x < 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), @@ -1025,36 +1038,36 @@ class TestComparisons(BaseCompare): # Merge a known floordiv and an upper bound of floormod into a value range TestCase( tir.all(fld(x, 10) == 5, flm(x, 10) < 7), - tir.all(tvm.runtime.convert(50) <= x, x < 57), + tir.all(T.int32(50) <= x, x < 57), ), TestCase( tir.all(fld(x, 10) == 5, flm(x, 10) <= 7), - tir.all(tvm.runtime.convert(50) <= x, x <= 57), + tir.all(T.int32(50) <= x, x <= 57), ), TestCase( tir.all(fld(x, 10) == -5, flm(x, 10) < 7), - tir.all(tvm.runtime.convert(-50) <= x, x < -43), + tir.all(T.int32(-50) <= x, x < -43), ), TestCase( tir.all(fld(x, 10) == -5, flm(x, 10) <= 7), - tir.all(tvm.runtime.convert(-50) <= x, x <= -43), + tir.all(T.int32(-50) <= x, x <= -43), ), # Merge a known floordiv and an lower bound of floormod into a value range TestCase( - tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) < flm(x, 10)), - tir.all(tvm.runtime.convert(57) < x, x < 60), + tir.all(fld(x, 10) == 5, T.int32(7) < flm(x, 10)), + tir.all(T.int32(57) < x, x < 60), ), TestCase( - tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) <= flm(x, 10)), - tir.all(tvm.runtime.convert(57) <= x, x < 60), + tir.all(fld(x, 10) == 5, T.int32(7) <= flm(x, 10)), + tir.all(T.int32(57) <= x, x < 60), ), TestCase( - tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) < flm(x, 10)), - tir.all(tvm.runtime.convert(-43) < x, x < -40), + tir.all(fld(x, 10) == -5, T.int32(7) < flm(x, 10)), + tir.all(T.int32(-43) < x, x < -40), ), TestCase( - tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) <= flm(x, 10)), - tir.all(tvm.runtime.convert(-43) <= x, x < -40), + tir.all(fld(x, 10) == -5, T.int32(7) <= flm(x, 10)), + tir.all(T.int32(-43) <= x, x < -40), ), TestCase(tvm.te.min(x, 11) < 10, x < 10), TestCase(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool")), @@ -1224,14 +1237,16 @@ class TestIfThenElse(BaseCompare): class TestCLZ(BaseCompare): test_case = tvm.testing.parameter( - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), 32), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), 31), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), 30), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), 24), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), 64), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), 63), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), 62), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), 56), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), T.int32(32)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), T.int32(31)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), T.int32(30)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), T.int32(24)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), T.int32(64)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), T.int32(63)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), T.int32(62)), + TestCase( + tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), T.int32(56) + ), ) diff --git a/tests/python/arith/test_arith_solve_linear_equations.py b/tests/python/arith/test_arith_solve_linear_equations.py index 24eb860c55f6..3195a4ae514f 100644 --- a/tests/python/arith/test_arith_solve_linear_equations.py +++ b/tests/python/arith/test_arith_solve_linear_equations.py @@ -19,6 +19,7 @@ import pytest import tvm from tvm import te, arith, ir, tir, testing +from tvm.script import tir as T def test_solution_consistency(): @@ -109,8 +110,8 @@ def test_unique_solution(): [x, y], ) assert list(solution.dst.variables) == [] - assert ir.structural_equal(solution.src_to_dst[x], 15) - assert ir.structural_equal(solution.src_to_dst[y], 5) + assert ir.structural_equal(solution.src_to_dst[x], T.int32(15)) + assert ir.structural_equal(solution.src_to_dst[y], T.int32(5)) def test_low_rank(): @@ -128,7 +129,7 @@ def test_low_rank(): [n0] = solution.dst.variables assert ir.structural_equal(solution.src_to_dst[x], n0 + 10) assert ir.structural_equal(solution.src_to_dst[y], -n0) - assert ir.structural_equal(solution.src_to_dst[z], 5) + assert ir.structural_equal(solution.src_to_dst[z], T.int32(5)) def test_infer_range(): @@ -149,12 +150,12 @@ def test_infer_range(): assert ir.structural_equal(solution.src_to_dst[x], n0) assert ir.structural_equal(solution.src_to_dst[y], -n0) # inferred from y's range - assert ir.structural_equal(solution.dst.ranges[n0].min, -9) - assert ir.structural_equal(solution.dst.ranges[n0].extent, 10) + assert ir.structural_equal(solution.dst.ranges[n0].min, T.int32(-9)) + assert ir.structural_equal(solution.dst.ranges[n0].extent, T.int32(10)) # additional inequality is added into the system for x [ineq] = solution.dst.relations assert isinstance(ineq, tvm.tir.LE) - assert ir.structural_equal(ineq.a, -5) + assert ir.structural_equal(ineq.a, T.int32(-5)) assert ir.structural_equal(ineq.b, n0) @@ -172,7 +173,7 @@ def test_ill_formed(): ) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - assert ir.structural_equal(rel, False) + ir.assert_structural_equal(rel, tir.const(False)) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/arith/test_arith_solve_linear_inequality.py b/tests/python/arith/test_arith_solve_linear_inequality.py index 5285da12e75d..664258ae7cf1 100644 --- a/tests/python/arith/test_arith_solve_linear_inequality.py +++ b/tests/python/arith/test_arith_solve_linear_inequality.py @@ -19,6 +19,7 @@ import pytest import tvm from tvm import te, arith, ir, tir, testing +from tvm.script import tir as T @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11458") @@ -113,10 +114,10 @@ def test_dual_variable(): [x_new, y_new] = solution.dst.variables [rel] = solution.dst.relations assert ir.structural_equal(rel, (y_new * 2) + x_new <= 10) - assert ir.structural_equal(solution.dst.ranges[x_new].min, 0) - assert ir.structural_equal(solution.dst.ranges[x_new].extent, 11) - assert ir.structural_equal(solution.dst.ranges[y_new].min, 0) - assert ir.structural_equal(solution.dst.ranges[y_new].extent, 6) + assert ir.structural_equal(solution.dst.ranges[x_new].min, T.int32(0)) + assert ir.structural_equal(solution.dst.ranges[x_new].extent, T.int32(11)) + assert ir.structural_equal(solution.dst.ranges[y_new].min, T.int32(0)) + assert ir.structural_equal(solution.dst.ranges[y_new].extent, T.int32(6)) assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10)) assert ir.structural_equal(solution.src_to_dst[y], y_new) assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) @@ -185,7 +186,7 @@ def test_no_solution(): solution = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - assert ir.structural_equal(rel, False) + ir.assert_structural_equal(rel, tir.const(False)) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 112c521d06d4..112d1151febd 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -769,7 +769,7 @@ def check_cuda(dtype, n, l, padding, lanes): (n // lanes, l + 2 * padding, lanes), lambda i, j, k: tvm.te.if_then_else( tvm.te.any(j < padding, j >= l + padding), - tvm.runtime.convert(0).astype(dtype), + tvm.tir.const(0, dtype), A[i * lanes + k, j - padding], ), name="B", diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f50d63878e4f..d9a6fd6e62d1 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1138,5 +1138,46 @@ def func(): tvm.build(func) +def test_int_parameter(): + """Boolean may be passed to functions accepting int""" + + @T.prim_func + def func(arg: T.int32) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + if arg > 0: + return 10 + else: + return 20 + + built = tvm.build(func) + output = built(True) + assert output == 10 + + output = built(False) + assert output == 20 + + +def test_bool_parameter(): + """Integers may be passed to functions accepting bool""" + + @T.prim_func + def func(arg: T.bool) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + if arg: + return 10 + else: + return 20 + + built = tvm.build(func) + output = built(1) + assert output == 10 + + output = built(2) + assert output == 10 + + output = built(0) + assert output == 20 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/ir/test_container_structural_equal.py b/tests/python/ir/test_container_structural_equal.py index 61511c609ca4..238a77b4ef4b 100644 --- a/tests/python/ir/test_container_structural_equal.py +++ b/tests/python/ir/test_container_structural_equal.py @@ -56,20 +56,20 @@ def get_first_mismatch_ensure_symmetry(a, b): ( [1, 2, 3], [1, 4, 3], - ObjectPath.root().array_index(1).attr("value"), - ObjectPath.root().array_index(1).attr("value"), + ObjectPath.root().array_index(1), + ObjectPath.root().array_index(1), ), ( [1, 2, 3], [10, 2, 30], - ObjectPath.root().array_index(0).attr("value"), - ObjectPath.root().array_index(0).attr("value"), + ObjectPath.root().array_index(0), + ObjectPath.root().array_index(0), ), ( [1, 3, 4], [1, 2, 3, 4], - ObjectPath.root().array_index(1).attr("value"), - ObjectPath.root().array_index(1).attr("value"), + ObjectPath.root().array_index(1), + ObjectPath.root().array_index(1), ), ( [1, 2, 3], @@ -121,14 +121,28 @@ def test_shape_tuple_structural_equal_to_self(contents): assert get_first_mismatch_ensure_symmetry(a, b) is None +@pytest.mark.parametrize( + "contents", + [ + {}, + {"a": 1, "b": 2}, + {"a": True, "b": False}, + ], +) +def test_string_map_structural_equal_to_self(contents): + a = tvm.runtime.convert({**contents}) + b = tvm.runtime.convert({**contents}) + assert get_first_mismatch_ensure_symmetry(a, b) is None + + @pytest.mark.parametrize( "a, b, expected_a_path, expected_b_path", [ ( dict(a=3, b=4), dict(a=3, b=5), - ObjectPath.root().map_value("b").attr("value"), - ObjectPath.root().map_value("b").attr("value"), + ObjectPath.root().map_value("b"), + ObjectPath.root().map_value("b"), ), ( dict(a=3, b=4), diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index aa482dd65cd7..1e3249197851 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -23,16 +23,19 @@ def test_array(): a = tvm.runtime.convert([1, 2, 3]) assert len(a) == 3 - assert a[-1].value == 3 + assert a[-1] == 3 a_slice = a[-3:-1] - assert (a_slice[0].value, a_slice[1].value) == (1, 2) + assert (a_slice[0], a_slice[1]) == (1, 2) def test_array_save_load_json(): - a = tvm.runtime.convert([1, 2, 3]) + a = tvm.runtime.convert([1, 2, 3.5, True]) json_str = tvm.ir.save_json(a) a_loaded = tvm.ir.load_json(json_str) - assert a_loaded[1].value == 2 + assert a_loaded[1] == 2 + assert a_loaded[2] == 3.5 + assert a_loaded[3] == True + assert isinstance(a_loaded[3], bool) def test_dir_array(): @@ -66,7 +69,7 @@ def test_str_map(): assert "a" in amap assert len(amap) == 2 dd = dict(amap.items()) - assert amap["a"].value == 2 + assert amap["a"] == 2 assert "a" in dd assert "b" in dd @@ -78,7 +81,7 @@ def test_map_save_load_json(): json_str = tvm.ir.save_json(amap) amap = tvm.ir.load_json(json_str) assert len(amap) == 2 - dd = {kv[0].name: kv[1].value for kv in amap.items()} + dd = {kv[0].name: kv[1] for kv in amap.items()} assert dd == {"a": 2, "b": 3} diff --git a/tests/python/ir/test_ir_type.py b/tests/python/ir/test_ir_type.py index 2355aa19adec..b70406c1bb7a 100644 --- a/tests/python/ir/test_ir_type.py +++ b/tests/python/ir/test_ir_type.py @@ -16,6 +16,7 @@ # under the License. """Test type nodes in the IR""" import tvm +from tvm.script import tir as T def check_json_roundtrip(node): @@ -38,11 +39,9 @@ def test_tensor_type_bad_constructor(): def test_tensor_type(): - shape = tvm.runtime.convert([1, 2, 3]) - dtype = "float32" - tt = tvm.ir.TensorType(shape, dtype) - assert tt.dtype == dtype - assert tt.shape == shape + tt = tvm.ir.TensorType([1, 2, 3], "float32") + assert tt.dtype == "float32" + assert list(tt.shape) == [T.int32(1), T.int32(2), T.int32(3)] assert tt.span == None str(tt) check_json_roundtrip(tt) diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index f1709c449d16..b0ddbe93601e 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -40,7 +40,7 @@ def test_constant(): ) assert ( constant.__str__() - == """R.dist.const(1, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" + == """R.dist.const(1.0, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" ) @@ -144,7 +144,7 @@ def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer(( vi, vj = T.axis.remap("SS", [i, j]) T.reads(x[vi, vj]) T.writes(y[vi, vj]) - y[vi, vj] = x[vi, vj] + T.float32(1) + y[vi, vj] = x[vi, vj] + T.float32(1.0) @R.function def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) -> R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"): diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 97ad9f5dd034..64d5c7381171 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -404,7 +404,7 @@ def f( "op": 'ExternFunc(global_symbol="contrib.tensor_array_stack")', "args": '[Var(name_hint="x"), Var(name_hint="y")]', "sinfo_args": "[ObjectStructInfo()]", - "attrs": '{"test_attr": 1}', + "attrs": '{"test_attr": True}', }, extern_call_text, ) diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py index 2ab5afaabf24..1efbd690f034 100644 --- a/tests/python/relax/test_backend_dispatch_sort_scan.py +++ b/tests/python/relax/test_backend_dispatch_sort_scan.py @@ -63,6 +63,13 @@ def foo(x: R.Tensor((2, 3), "float32", "llvm")): def test_dispatch_scanop_cuda(): + """R.cumsum and R.cumprod may be lowered with TOPI for GPU + + For the purpose of testing, this test case intentionally uses the + `exclusive=True` argument to prevent the `R.cumsum` from being + lowered to the packed func `"gpu_2d_continuous_cumsum"`. + """ + @I.ir_module class Before: I.module_global_infos({"vdevice": [I.vdevice("cuda", 0)]}) @@ -70,7 +77,7 @@ class Before: @R.function def main(x: R.Tensor(("m", 3), "float32", "cuda")): with R.dataflow(): - lv0 = R.cumsum(x, axis=1) + lv0 = R.cumsum(x, axis=1, exclusive=True) lv1 = R.cumprod(lv0, axis=1) gv = lv1 R.output(gv) @@ -89,6 +96,7 @@ def main(x: R.Tensor(("m", 3), "float32", "cuda")): topi.cuda.cumsum, x, axis=1, + exclusive=True, ) out = bb.emit_te( topi.cuda.cumprod, diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 7b64eb1dee39..e93547d83e3c 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -395,7 +395,7 @@ def test_call_tir_with_grad(): """ v0: R.Tensor((54, 96), dtype="float32") x = T.int64() -R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": T.float32(1), "x": x}) +R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": 1.0, "x": x}) """, ) @@ -758,7 +758,7 @@ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": True}) R.print(format=R.str("Hi there!")) z: R.Tensor((), dtype="int32") = R.add(x, x) return z @@ -770,7 +770,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function(private=True) def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": True}) R.print(format=R.str("Lol")) z: R.Tensor((), dtype="int32") = R.multiply(x, x) return z diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index ab40e181a35a..30fd06d4f14d 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -566,7 +566,7 @@ def main(shape: R.Prim(value="n")): assert func(2) == 4 - with pytest.raises(tvm.TVMError): + with pytest.raises(TypeError): func(ShapeTuple([2])) diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 9a4817f5fd8a..60f096585dfe 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -118,9 +118,10 @@ class Expected: @T.prim_func def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): T.func_attr({"global_symbol": "__vmtir__ife"}) - if T.cast( - T.tvm_call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))), + if T.Call( "bool", + tvm.ir.Op.get("tir.tvm_call_packed"), + ["vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))], ): T.anylist_setitem_call_packed( r, diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 4031790fc383..b79713e05ed3 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -18,6 +18,7 @@ import numpy as np import tvm +from tvm.script import tir as T from tvm import relay from tvm.relay.build_module import bind_params_by_name from tvm.relay.dataflow_pattern import * @@ -115,7 +116,7 @@ def test_DataTypePattern(): def test_ShapePattern(): - shape = [10, 10] + shape = [T.int32(10), T.int32(10)] pattern = has_shape(shape) assert isinstance(pattern, ShapePattern) tvm.ir.assert_structural_equal(pattern.shape, shape) diff --git a/tests/python/relay/test_executor.py b/tests/python/relay/test_executor.py index d703ef1f3d9a..04662f21ae9e 100644 --- a/tests/python/relay/test_executor.py +++ b/tests/python/relay/test_executor.py @@ -57,7 +57,7 @@ def test_create_executor_attr_type_incorrect(): with pytest.raises( TVMError, match='Attribute "interface-api" should have type "runtime.String"' - ' but instead found "IntImm"', + ' but instead found "runtime.BoxBool"', ): Executor("aot", {"interface-api": True}) diff --git a/tests/python/relay/test_runtime.py b/tests/python/relay/test_runtime.py index ea15dd0d3c88..db8252f3a3c4 100644 --- a/tests/python/relay/test_runtime.py +++ b/tests/python/relay/test_runtime.py @@ -51,7 +51,7 @@ def test_create_runtime_attr_not_found(): def test_create_runtime_attr_type_incorrect(): with pytest.raises( TVMError, - match='Attribute "system-lib" should have type "IntImm"' + match='Attribute "system-lib" should have type "runtime.BoxBool"' ' but instead found "runtime.String"', ): Runtime("crt", {"system-lib": "woof"}) @@ -65,7 +65,7 @@ def test_list_runtimes(): def test_list_runtime_options(runtime): aot_options = Runtime.list_registered_options(runtime) assert "system-lib" in aot_options - assert aot_options["system-lib"] == "IntImm" + assert aot_options["system-lib"] == "runtime.BoxBool" def test_list_runtime_options_not_found(): diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index f18994d52ce9..7d0cd51d3298 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -18,12 +18,13 @@ for expressions. """ import pytest +import numpy as np + import tvm -from tvm import IRModule, parser, relay, te -from tvm.relay import analysis, op, transform +from tvm import IRModule, relay +from tvm.relay import op, transform from tvm.relay.op import op as _op - -import numpy as np +from tvm.script import tir as T def infer_mod(mod, annotate_spans=True): @@ -554,40 +555,32 @@ def test_repeat_register(): assert "Operator custom_log3 is registered before" in str(cm.execption) -def test_argreduce_infer_return_type(): +@pytest.mark.parametrize("relay_op", [relay.op.argmax, relay.op.argmin]) +@pytest.mark.parametrize( + "shape_dtype", + [ + ("int32", T.int32), + ("int64", T.int64), + ], + ids=["int32", "int64"], +) +def test_argreduce_infer_return_type(relay_op, shape_dtype): x_shape = (1, 1) broadcast_shape = [1, 1] - shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))] - - # Testing with argmax - for (sdtype, conv) in shape_dtypes: - x = relay.var("data", relay.TensorType(x_shape, "float32")) - broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) - argmax = relay.op.argmax(broadcast_to, axis=[1]) - - f = relay.Function([x], argmax) - assert_has_type( - f, - relay.FuncType( - [relay.TensorType(broadcast_shape, "float32")], - relay.TensorType([conv(1)], dtype=sdtype), - ), - ) - - # Testing with argmin - for (sdtype, conv) in shape_dtypes: - x = relay.var("data", relay.TensorType(x_shape, "float32")) - broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) - argmin = relay.op.argmin(broadcast_to, axis=[1]) - - f = relay.Function([x], argmin) - assert_has_type( - f, - relay.FuncType( - [relay.TensorType(broadcast_shape, "float32")], - relay.TensorType([conv(1)], dtype=sdtype), - ), - ) + (sdtype, conv) = shape_dtype + + x = relay.var("data", relay.TensorType(x_shape, "float32")) + broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) + argmax = relay_op(broadcast_to, axis=[1]) + + f = relay.Function([x], argmax) + assert_has_type( + f, + relay.FuncType( + [relay.TensorType(broadcast_shape, "float32")], + relay.TensorType([conv(1)], dtype=sdtype), + ), + ) if __name__ == "__main__": diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py index 7538075ae7f8..e0d216b33e9a 100644 --- a/tests/python/runtime/test_runtime_container.py +++ b/tests/python/runtime/test_runtime_container.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. -import numpy as np +import pickle import random + +import numpy as np + import tvm import tvm.testing -import pickle -from tvm import te from tvm import nd, relay from tvm.runtime import container as _container @@ -96,8 +97,123 @@ def test_shape_tuple(): assert stuple == z +def test_bool_argument(): + """Boolean objects are currently stored as int""" + func = tvm.get_global_func("testing.AcceptsBool") + + assert isinstance(func(True), bool) + assert isinstance(func(1), bool) + assert isinstance(func(0), bool) + + +def test_int_argument(): + func = tvm.get_global_func("testing.AcceptsInt") + + assert isinstance(func(True), int) + assert isinstance(func(1), int) + assert isinstance(func(0), int) + + +def test_object_ref_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRef") + + assert isinstance(func(True), bool) + assert isinstance(func(1), int) + assert isinstance(func(3.5), float) + assert func(3.5) == 3.5 + + +def test_object_ref_array_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRefArray") + + assert isinstance(func([True, 17, "hello"]), bool) + assert isinstance(func([True]), bool) + assert isinstance(func([17]), int) + assert isinstance(func(["hello"]), str) + + +def test_map_argument_returns_value(): + func = tvm.get_global_func("testing.AcceptsMapReturnsValue") + + res = func({"a": 1, "b": 2}, "a") + assert isinstance(res, int) + assert res == 1 + + res = func({"a": True, "b": False}, "a") + assert isinstance(res, bool) + assert res == True + + +def test_map_argument_returns_map(): + func = tvm.get_global_func("testing.AcceptsMapReturnsMap") + + res = func({"a": 1, "b": 2}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, int) + + res = func({"a": False, "b": True}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, bool) + + +def test_conversion_of_arg(): + """Arguments may be converted + + The calling side of the FFI converts to types that are available + at runtime. However, there may be additional type conversions + required, that must be performed on the callee-side of the FFI. + """ + + func = tvm.get_global_func("testing.AcceptsPrimExpr") + + res = func(1) + assert isinstance(res, tvm.tir.IntImm) + assert res.dtype == "int32" + + res = func(True) + assert isinstance(res, tvm.tir.IntImm) + assert res.dtype == "bool" + + +def test_conversion_of_array_elements(): + """Elements of an array may require conversion from FFI to param type + + Like `test_conversion_of_arg`, but conversions must be applied + recursively to array elements. Here, the Python-side of the FFI + converts the array `[1,2]` to `Array{runtime::Int(1), + runtime::Int(2)}`, and the C++ side of the FFI converts to + `Array{IntImm(1), IntImm(2)}`. + """ + + func = tvm.get_global_func("testing.AcceptsArrayOfPrimExpr") + + res = func([1, False]) + assert isinstance(res[0], tvm.tir.IntImm) + assert res[0].dtype == "int32" + assert isinstance(res[1], tvm.tir.IntImm) + assert res[1].dtype == "bool" + + +def test_conversion_of_map_values(): + """Elements of a map may require conversion from FFI to param type + + Like `test_conversion_of_arg`, but conversions must be applied + recursively to map elements. Here, the Python-side of the FFI + converts the map `{'a':1, 'b':2}` to `Map{{"a", runtime::Int(1)}, + {"b", runtime::Int(2)}}`, and the C++ side of the FFI converts to + `Map{{"a", IntImm(1)}, {"b", IntImm(2)}}`. + """ + + func = tvm.get_global_func("testing.AcceptsMapOfPrimExpr") + + res = func({"a": 1, "b": False}) + assert isinstance(res["a"], tvm.tir.IntImm) + assert res["a"].dtype == "int32" + assert isinstance(res["b"], tvm.tir.IntImm) + assert res["b"].dtype == "bool" + + if __name__ == "__main__": - test_string() - test_adt_constructor() - test_tuple_object() - test_shape_tuple() + tvm.testing.main() diff --git a/tests/python/te/test_te_schedule_tensorize.py b/tests/python/te/test_te_schedule_tensorize.py index 79aecb78902a..419d3edb5c3d 100644 --- a/tests/python/te/test_te_schedule_tensorize.py +++ b/tests/python/te/test_te_schedule_tensorize.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T def intrin_vadd(xo, m, n): @@ -100,6 +101,7 @@ def add(m): def check(m, factor): x, y, z = add(m) + factor = T.int32(factor) s = te.create_schedule(z.op) xo, xi = s[z].split(z.op.axis[0], factor=factor) vadd = intrin_vadd(xo, m, factor) @@ -133,7 +135,7 @@ def check_cache_write(m, factor): finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[z_global], dom_map) # outer loop var will be rebased, so min value is the new loop var and extent is 1 - tvm.ir.assert_structural_equal(out_dom[xo].extent, 1) + tvm.ir.assert_structural_equal(out_dom[xo].extent, T.int32(1)) assert isinstance(out_dom[xo].min, tvm.tir.Var) assert xo.var.name == out_dom[xo].min.name @@ -183,7 +185,7 @@ def check(factor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -207,7 +209,7 @@ def check_rfactor(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -230,7 +232,7 @@ def check_rfactor_no_reset(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -254,7 +256,7 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -264,10 +266,10 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) - check(16) - check_rfactor(16, 16) - check_rfactor_no_reset(16, 16) - check_rfactor_no_reset_multi_reduction(16, 16) + check(T.int32(16)) + check_rfactor(T.int32(16), T.int32(16)) + check_rfactor_no_reset(T.int32(16), T.int32(16)) + check_rfactor_no_reset_multi_reduction(T.int32(16), T.int32(16)) # This tests whether algorithm and intrinsics expressions are simplified diff --git a/tests/python/te/test_te_tag.py b/tests/python/te/test_te_tag.py index 6e88a12614cf..a4b76e7d6736 100644 --- a/tests/python/te/test_te_tag.py +++ b/tests/python/te/test_te_tag.py @@ -57,12 +57,12 @@ def test_with(): assert C.op.tag == "gemm" assert "hello" in C.op.attrs assert "xx" not in C.op.attrs - assert C.op.attrs["hello"].value == 1 + assert C.op.attrs["hello"] == 1 CC = tvm.ir.load_json(tvm.ir.save_json(C)) - assert CC.op.attrs["hello"].value == 1 - assert CC.op.attrs["arr"][0].value == 10 - # str format happened to be json compatible - assert json.loads(str(CC.op.attrs))["arr"][1] == 12 + assert CC.op.attrs["hello"] == 1 + assert len(CC.op.attrs["arr"]) == 2 + assert CC.op.attrs["arr"][0] == 10 + assert CC.op.attrs["arr"][1] == 12 def test_decorator(): diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py index e94a4f09ec56..0e610cc1659b 100644 --- a/tests/python/tir-base/test_lower_build.py +++ b/tests/python/tir-base/test_lower_build.py @@ -122,7 +122,7 @@ def test_lower_build_tir_func(): def test_lower_build_tir_module(): func = matmul.with_attr("global_symbol", "main") - func = func.with_attr("tir.noalias", True) + func = func.with_attr("tir.noalias", T.bool(True)) ir_mod = IRModule({"main": func}) # check lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index b4b773197b14..d706e65d8186 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest + import tvm import tvm.testing from tvm import te from tvm.tir import Buffer +from tvm.script import tir as T + import numpy as np +import pytest def test_buffer(): @@ -78,9 +81,9 @@ def test_buffer_access_ptr_extent(): # Test extent from input params aptr = Ab.access_ptr("rw", extent=200) - tvm.ir.assert_structural_equal(aptr.args[3], 200) + tvm.ir.assert_structural_equal(aptr.args[3], T.int32(200)) aptr = Ab.access_ptr("rw", offset=100, extent=100) - tvm.ir.assert_structural_equal(aptr.args[3], 100) + tvm.ir.assert_structural_equal(aptr.args[3], T.int32(100)) def test_buffer_vload(): @@ -88,7 +91,7 @@ def test_buffer_vload(): n = te.size_var("n") Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) - tvm.ir.assert_structural_equal(load.indices, [2, 3]) + tvm.ir.assert_structural_equal(load.indices, [T.int32(2), T.int32(3)]) def test_buffer_offset_of(): @@ -259,7 +262,7 @@ def test_buffer_flatten(): buf = tvm.tir.decl_buffer([16, 32]) flat = buf.get_flattened_buffer() assert buf.data.same_as(flat.data) - tvm.ir.assert_structural_equal(flat.shape, [16 * 32]) + tvm.ir.assert_structural_equal(flat.shape, [T.int32(16 * 32)]) def test_buffer_flatten_preserves_identity(): @@ -273,8 +276,8 @@ def test_buffer_flatten_uses_axis_separators(): """Flattening to N-d physical buffers uses the axis separators""" buf = tvm.tir.decl_buffer([4, 16, 32], axis_separators=[2]) flat = buf.get_flattened_buffer() - tvm.ir.assert_structural_equal(flat.axis_separators, [1]) - tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32]) + tvm.ir.assert_structural_equal(flat.axis_separators, [T.int32(1)]) + tvm.ir.assert_structural_equal(flat.shape, [T.int32(4 * 16), T.int32(32)]) def test_invalid_axis_separators_raises_exception(): diff --git a/tests/python/tir-base/test_tir_index_map.py b/tests/python/tir-base/test_tir_index_map.py index e893ed897d65..3ddbd2f69f59 100644 --- a/tests/python/tir-base/test_tir_index_map.py +++ b/tests/python/tir-base/test_tir_index_map.py @@ -22,6 +22,7 @@ from tvm.ir import assert_structural_equal from tvm.runtime import const from tvm.tir import IndexMap, IntImm, floordiv, floormod +from tvm.script import tir as T def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: @@ -37,28 +38,22 @@ def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: def test_index_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") - assert_structural_equal(index_map.map_indices([0]), [0, 0]) - assert_structural_equal(index_map.map_indices([3]), [0, 3]) - assert_structural_equal(index_map.map_indices([4]), [1, 0]) - assert_structural_equal(index_map.map_indices([42]), [10, 2]) - assert_structural_equal( - index_map.map_indices([const(42, "int64")]), [const(10, "int64"), const(2, "int64")] - ) + assert_structural_equal(index_map.map_indices([0]), [T.int32(0), T.int32(0)]) + assert_structural_equal(index_map.map_indices([3]), [T.int32(0), T.int32(3)]) + assert_structural_equal(index_map.map_indices([4]), [T.int32(1), T.int32(0)]) + assert_structural_equal(index_map.map_indices([42]), [T.int32(10), T.int32(2)]) + assert_structural_equal(index_map.map_indices([T.int64(42)]), [T.int64(10), T.int64(2)]) def test_shape_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") - assert_structural_equal(index_map.map_shape([4]), [1, 4]) - assert_structural_equal(index_map.map_shape([16]), [4, 4]) + assert_structural_equal(index_map.map_shape([4]), [T.int32(1), T.int32(4)]) + assert_structural_equal(index_map.map_shape([16]), [T.int32(4), T.int32(4)]) - assert_structural_equal(index_map.map_shape([14]), [4, 4]) - assert_structural_equal( - index_map.map_shape([const(16, "int64")]), [const(4, "int64"), const(4, "int64")] - ) - assert_structural_equal( - index_map.map_shape([const(14, "int64")]), [const(4, "int64"), const(4, "int64")] - ) + assert_structural_equal(index_map.map_shape([14]), [T.int32(4), T.int32(4)]) + assert_structural_equal(index_map.map_shape([T.int64(16)]), [T.int64(4), T.int64(4)]) + assert_structural_equal(index_map.map_shape([T.int64(14)]), [T.int64(4), T.int64(4)]) def test_inverse(): @@ -82,28 +77,28 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[16], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.runtime.convert(False), ), "right_padding": dict( forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[15], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), ), "left_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[15], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.And(i == 0, j < 1), ), "left_and_right_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[14], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.Or( tvm.tir.And(i == 0, j < 1), tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), @@ -113,7 +108,7 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[dynamic_N], - post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, 4], + post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, T.int32(4)], padding=lambda i, j: tvm.tir.And( dynamic_N % (-4) != 0, tvm.tir.And(i == dynamic_N // 4, j >= dynamic_N % 4), @@ -127,10 +122,10 @@ def test_nonbijective_inverse_gives_error(): ], pre_shape=[14, 31], post_shape=[ - 4, # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 - 5, # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 - 4, # Range of iter%4 - 8, # Range of iter%8 + T.int32(4), # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 + T.int32(5), # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 + T.int32(4), # Range of iter%4 + T.int32(8), # Range of iter%8 ], padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or( tvm.tir.Or( @@ -147,35 +142,35 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 32, (i // 4) % 8, i % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k], pre_shape=[116], - post_shape=[4, 8, 4], + post_shape=[T.int32(4), T.int32(8), T.int32(4)], padding=lambda i, j, k: tvm.tir.And(i == 3, 4 * j + k >= 20), ), "multiple_right_padding_transpose": dict( forward=lambda i: [(i // 4) % 8, i // 32, i % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k], pre_shape=[116], - post_shape=[8, 4, 4], + post_shape=[T.int32(8), T.int32(4), T.int32(4)], padding=lambda j, i, k: tvm.tir.And(i == 3, 4 * j + k >= 20), ), "multiple_left_padding": dict( forward=lambda i: [(i + 5) // 32, ((i + 5) // 4) % 8, (i + 5) % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k - 5], pre_shape=[123], - post_shape=[4, 8, 4], + post_shape=[T.int32(4), T.int32(8), T.int32(4)], padding=lambda i, j, k: tvm.tir.And(i == 0, j * 4 + k < 5), ), "multiple_left_padding_with_transpose": dict( forward=lambda i: [((i + 5) // 4) % 8, (i + 5) // 32, (i + 5) % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k - 5], pre_shape=[123], - post_shape=[8, 4, 4], + post_shape=[T.int32(8), T.int32(4), T.int32(4)], padding=lambda j, i, k: tvm.tir.And(i == 0, j * 4 + k < 5), ), "outer_loop_extent_one": dict( forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [i * 4 + j], pre_shape=[3], - post_shape=[1, 4], + post_shape=[T.int32(1), T.int32(4)], padding=lambda i, j: tvm.runtime.convert(3) == j, ), } diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index eeedae1f127c..29efd95280be 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -32,7 +32,7 @@ def test_te_const(): assert isinstance(x, tvm.tir.IntImm) -def test_scalar_dtype_inference(): +def test_tir_const_dtype_inference(): for data in [ True, bool(1), @@ -49,28 +49,11 @@ def test_scalar_dtype_inference(): np.float64(1), ]: assert tvm.tir.const(data).dtype == str(np.array(data).dtype) + + assert tvm.tir.const(True).dtype == "bool" assert tvm.tir.const(1).dtype == "int32" assert tvm.tir.const(1.0).dtype == "float32" - for data in [ - True, - bool(1), - np.uint8(1), - np.uint16(1), - np.uint32(1), - np.uint64(1), - np.int8(1), - np.int16(1), - np.int32(1), - np.int64(1), - np.float16(1), - np.float32(1), - np.float64(1), - ]: - assert tvm.runtime.convert(data).dtype == str(np.array(data).dtype) - assert tvm.runtime.convert(1).dtype == "int32" - assert tvm.runtime.convert(1.0).dtype == "float32" - def test_make(): x = tvm.tir.const(1, "int32") @@ -133,7 +116,7 @@ def test_attr(): assert stmt.node == y a = tvm.runtime.convert(1) - assert a.value == 1 + assert a == 1 try: a.no_field assert False @@ -350,7 +333,7 @@ def test_prim_func(): assert len(func.buffer_map) == 1 f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True}) - assert f2.attrs["calling_conv"].value == 1 + assert f2.attrs["calling_conv"] == 1 assert not func.attrs diff --git a/tests/python/tir-schedule/test_tir_schedule_sampling.py b/tests/python/tir-schedule/test_tir_schedule_sampling.py index c2f3f89e6e12..8ae576e9b922 100644 --- a/tests/python/tir-schedule/test_tir_schedule_sampling.py +++ b/tests/python/tir-schedule/test_tir_schedule_sampling.py @@ -146,7 +146,7 @@ def test_sample_categorical_serialize(): decisions.append(rv) new_sch = verify_trace_roundtrip(sch, mod=elementwise) for i, new_inst in enumerate(new_sch.trace.insts): - assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] + assert decisions[i] == candidates[new_sch.trace.decisions[new_inst]] def test_sample_perfect_tile_power_of_two(): diff --git a/tests/python/tir-schedule/test_tir_schedule_state.py b/tests/python/tir-schedule/test_tir_schedule_state.py index 74880e5a42d9..c023b9dbc59d 100644 --- a/tests/python/tir-schedule/test_tir_schedule_state.py +++ b/tests/python/tir-schedule/test_tir_schedule_state.py @@ -155,10 +155,10 @@ def test_replace_direct_write0(): old_hash = s.mod["main"].__hash__() sref = s.get_sref(s.mod["main"].body.block.body[1]) s.replace(sref, target) - # There is no other reference so the AST node can be written directly - assert old_hash == s.mod["main"].__hash__() # Check the replaced part is equal to the target tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target) + # There is no other reference so the AST node can be written directly + assert old_hash == s.mod["main"].__hash__() # The target reuse the stmt of the sref, so the sref won't be None assert sref.stmt is not None diff --git a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py index d5d5e0634ef6..cb7151f875e3 100644 --- a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py +++ b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py @@ -1029,38 +1029,45 @@ class TestTileAwareCompaction(BaseCompactTest): # it is not an opaque block case intentionally is_lower_order_free = False - @T.prim_func - def before( - A: T.Buffer((128, 128), "float32"), - B: T.Buffer((128, 128), "float32"), - C: T.Buffer((128, 128), "float32"), - ): - for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): - for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): - A_local = T.decl_buffer((26, 128), scope="local") - B_local = T.decl_buffer((128, 26), scope="local") - C_local = T.decl_buffer((26, 26), scope="local") - for ax0, ax1 in T.grid(26, 128): - if i_0 * 26 + ax0 < 128: - A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1] - for ax0, ax1 in T.grid(128, 26): - if j_0 * 26 + ax1 < 128: - B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1] - for i_1, j_1, k in T.grid(26, 26, 128): - if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128: - if k == 0: - C_local[i_1, j_1] = T.float32(0) - C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] - for ax0, ax1 in T.grid(26, 26): - if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128: - C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] - - # Get partitioned workload to compact - before_mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): - before_mod = tvm.tir.transform.LowerOpaqueBlock()(before_mod) - before_mod = tvm.tir.transform.LoopPartition()(before_mod) - before = before_mod["main"] + @property + def before(self): + @T.prim_func + def main( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): + for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): + A_local = T.decl_buffer((26, 128), scope="local") + B_local = T.decl_buffer((128, 26), scope="local") + C_local = T.decl_buffer((26, 26), scope="local") + for ax0, ax1 in T.grid(26, 128): + if i_0 * 26 + ax0 < 128: + A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1] + for ax0, ax1 in T.grid(128, 26): + if j_0 * 26 + ax1 < 128: + B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1] + for i_1, j_1, k in T.grid(26, 26, 128): + if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128: + if k == 0: + C_local[i_1, j_1] = T.float32(0) + C_local[i_1, j_1] = ( + C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] + ) + for ax0, ax1 in T.grid(26, 26): + if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128: + C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] + + # Get partitioned workload to compact + mod = tvm.IRModule.from_expr(main) + with tvm.transform.PassContext( + config={"tir.LoopPartition": {"partition_const_loop": True}} + ): + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.tir.transform.LoopPartition()(mod) + + return mod["main"] @T.prim_func def expected( diff --git a/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py b/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py index 9f61b5a3920a..3078572bb508 100644 --- a/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest + import tvm import tvm.testing -from tvm import te +from tvm import te, tir + +import pytest import numpy as np @@ -184,7 +186,7 @@ def collect_branch_stmt(x): if isinstance(x, tvm.tir.IfThenElse): branch_collector.append(x) - n = 21 + n = tir.const(21) A = te.placeholder((n,), name="A") B = te.placeholder((n,), name="B") diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index 23a51a0817df..0b43db56f300 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -394,5 +394,144 @@ def func_without_arg( tvm.ir.assert_structural_equal(Expected, After) +def test_int_parameter(): + """Boolean may be passed to functions accepting int + + A PackedFunc produced by compiling an IRModule should support the + same type conversions as the C++ implementation. When a function + accepts an integer argument, the caller may call it with a boolean + value. + + This also provides backwards compatibility for functions that were + defined as accepting an integer, but are called with a boolean + argument. Prior to PackedFunc interface supporting boolean + arguments directly, the argument would be converted from boolean + to integer to be stored in a TVMValue. After adding support for + boolean arguments, this usage should not cause an error. + + """ + + @I.ir_module + class Before: + @T.prim_func + def main(arg: T.int32) -> T.int32: + T.func_attr({"target": T.target("llvm", host="llvm")}) + if arg > 0: + return 10 + else: + return 20 + + @I.ir_module + class Expected: + @T.prim_func + def main( + args: T.handle, + arg_type_ids: T.handle("int32"), + num_args: T.int32, + out_ret_value: T.handle("void"), + out_ret_tcode: T.handle("int32"), + resource_handle: T.handle, + ) -> T.int32: + T.func_attr( + { + "calling_conv": 1, + "target": T.target("llvm"), + } + ) + assert num_args == 1, "main: num_args should be 1" + assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" + assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" + arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) + arg_code: T.int32 = arg_type_ids_1[0] + assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be int" + arg: T.int32 = T.if_then_else( + arg_code == 0, + T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")), + T.Cast("int32", T.tvm_struct_get(args, 0, 12, "bool")), + ) + with T.attr(0, "compute_scope", "main_compute_"): + out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) + out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) + if arg > 0: + out_ret_value_1[0] = T.Cast("int64", 10) + out_ret_tcode_1[0] = 0 + return 0 + else: + out_ret_value_1[0] = T.Cast("int64", 20) + out_ret_tcode_1[0] = 0 + return 0 + return 0 + + After = tvm.tir.transform.MakePackedAPI()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + +def test_bool_parameter(): + """An integer may be passed to a function acccepting Boolean + + A PackedFunc produced by compiling an IRModule should support the + same type conversions as the C++ implementation. When a function + accepts a boolean argument, the caller may call it with an integer + value. + + """ + + @I.ir_module + class Before: + @T.prim_func + def main(arg: T.bool) -> T.int32: + T.func_attr({"target": T.target("llvm", host="llvm")}) + if arg: + return 10 + else: + return 20 + + @I.ir_module + class Expected: + @T.prim_func + def main( + args: T.handle, + arg_type_ids: T.handle("int32"), + num_args: T.int32, + out_ret_value: T.handle("void"), + out_ret_tcode: T.handle("int32"), + resource_handle: T.handle, + ) -> T.int32: + T.func_attr( + { + "calling_conv": 1, + "target": T.target("llvm"), + } + ) + assert num_args == 1, "main: num_args should be 1" + assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" + assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" + arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) + arg_code: T.int32 = arg_type_ids_1[0] + assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be boolean" + arg: T.bool = T.if_then_else( + arg_code == 15, + T.tvm_struct_get(args, 0, 12, "bool"), + T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")), + ) + with T.attr(0, "compute_scope", "main_compute_"): + out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) + out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) + if arg: + out_ret_value_1[0] = T.Cast("int64", 10) + out_ret_tcode_1[0] = 0 + return 0 + else: + out_ret_value_1[0] = T.Cast("int64", 20) + out_ret_tcode_1[0] = 0 + return 0 + return 0 + + After = tvm.tir.transform.MakePackedAPI()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index 4b71eb825414..68149e7d64bb 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -937,8 +937,8 @@ def test_vulkan_smem_reuse(): "kind": "vulkan", "max_num_threads": 256, "max_threads_per_block": 256, - "supports_float32": T.bool(True), - "supports_int32": T.bool(True), + "supports_float32": True, + "supports_int32": True, "tag": "", "thread_warp_size": 1, } diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py index 279785fdca51..d8212d38854c 100644 --- a/tests/python/tvmscript/test_tvmscript_error_report.py +++ b/tests/python/tvmscript/test_tvmscript_error_report.py @@ -332,26 +332,35 @@ def convert_slice_to_bufferload() -> None: check_error(convert_slice_to_bufferload, 6) -def test_tvm_exception_catch(): +def test_tvm_exception_catch_from_special_stmt(): def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error T.evaluate(1.0) + check_error(special_stmt_except, 2) + + +def test_tvm_exception_catch_from_scope_handler(): def scope_handler_except() -> None: for i in T.serial("1", "1"): # error T.evaluate(1) + check_error(scope_handler_except, 2) + + +def test_tvm_exception_catch_from_bare_intrin(): def intrin_except_unassign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") T.evaluate(A) # error + check_error(intrin_except_unassign, 3) + + +def test_tvm_exception_catch_from_assigned_intrin(): def intrin_except_assign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") A[0, 0] = A[A] # error - check_error(special_stmt_except, 2) - check_error(scope_handler_except, 2) - check_error(intrin_except_unassign, 3) check_error(intrin_except_assign, 3) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 8364e65a4178..b7ba57fa9387 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -230,7 +230,7 @@ def test_buffer_store(): obj, """ A = T.Buffer((128, 128), "float16") -A[128, 128] = A[128, 128] + T.float16(1) +A[128, 128] = A[128, 128] + T.float16(1.0) """, ) @@ -259,7 +259,7 @@ def test_let_stmt(): _assert_print( obj, """ -with T.LetStmt(T.float32(10)) as v: +with T.LetStmt(T.float32(10.0)) as v: T.evaluate(0) """, ) @@ -672,7 +672,7 @@ def test_call(): _assert_print( obj, """ -T.atan(T.float32(1)) +T.atan(T.float32(1.0)) """, ) @@ -682,7 +682,7 @@ def test_comm_reducer(): _assert_print( obj, """ -T.comm_reducer(lambda x, y: x + y, [T.float32(0)]) +T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]) """, ) @@ -712,7 +712,7 @@ def test_float_imm(): _assert_print( obj, """ -T.float16(1) +T.float16(1.0) """, ) @@ -942,7 +942,7 @@ def func(): @T.prim_func def func(): - T.evaluate(T.{dtype}(0)) + T.evaluate(T.{dtype}(0.0)) """ func = get_func(dtype) _assert_print(func, expected_output) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index f81a80de6d61..b44ff5ad7241 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2689,14 +2689,14 @@ def test_match_buffer_region(): outer_block = root.body.body.body.block assert len(outer_block.match_buffers) == 1 buffer_C = outer_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) + tvm.ir.assert_structural_equal(buffer_C.shape, [T.int32(16), T.int32(1), T.int32(4)]) assert isinstance(outer_block.body, tir.stmt.For) assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) inner_block = outer_block.body.body.block assert len(inner_block.match_buffers) == 1 buffer_D = inner_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) + tvm.ir.assert_structural_equal(buffer_D.shape, [T.int32(4), T.int32(1), T.int32(4)]) def block_elements(): @@ -3981,6 +3981,32 @@ def func() -> T.int32: return func +def func_attr_with_list(): + @T.prim_func + def func( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + D: T.Buffer((128, 128), "float32"), + ) -> None: + T.func_attr( + {"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [T.int32(1)]} + ) + C = T.alloc_buffer([128, 128], dtype="float32") + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("C"): + x, y, k = T.axis.remap("SSR", [i0, i1, i2]) + with T.init(): + C[x, y] = T.float32(0) + C[x, y] = C[x, y] + A[x, k] * B[y, k] + for i0, i1 in T.grid(128, 128): + with T.block("D"): + T.block_attr({"layout_free_placeholders": [C]}) + x, y = T.axis.remap("SS", [i0, i1]) + D[x, y] = C[x, y] + T.float32(1) + + return func + + def op_of_literal(): op_list = [ (T.exp, 0), @@ -4198,6 +4224,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return_zero, return_zero_private, return_zero_private_with_attr, + func_attr_with_list, *op_of_literal(), *relax_match_cast_struct_info_proxy(), relax_symbolic_size_var, diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 9bc9800c1cb8..ae83a9d66392 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -19,6 +19,7 @@ import tvm from tvm import te from tvm.topi import utils +from tvm.script import tir as T from .environment import get_env @@ -1046,19 +1047,19 @@ def _flatten_loop(src_coeff, dst_coeff, extents): assert len(dst_coeff) > 1 assert len(extents) != 0 tvm.ir.assert_structural_equal( - analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0 + analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), T.int32(0) ) tvm.ir.assert_structural_equal( - analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0 + analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), T.int32(0) ) - tvm.ir.assert_structural_equal(src_coeff[-2], 1) - tvm.ir.assert_structural_equal(dst_coeff[-2], 1) + tvm.ir.assert_structural_equal(src_coeff[-2], T.int32(1)) + tvm.ir.assert_structural_equal(dst_coeff[-2], T.int32(1)) if env.BATCH > 1: assert len(src_coeff) > 2 assert len(dst_coeff) > 2 assert len(extents) > 1 - tvm.ir.assert_structural_equal(src_coeff[-3], env.BLOCK_OUT) - tvm.ir.assert_structural_equal(dst_coeff[-3], env.BLOCK_OUT) + tvm.ir.assert_structural_equal(src_coeff[-3], T.int32(env.BLOCK_OUT)) + tvm.ir.assert_structural_equal(dst_coeff[-3], T.int32(env.BLOCK_OUT)) # Apply tensorization of the loop coefficients src_offset = src_coeff[-1] From 591cf1ec4281872b97449fdd0da56ff255c9f383 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Aug 2024 07:03:37 -0500 Subject: [PATCH 049/202] [Relax] Remove segfault in R.call_tir_inplace validation (#17242) Prior to this commit, the error message produced when validating `R.call_tir_inplace` included the shape of the argument that will be mutated in-place. This correctly caught and raised an error when the argument is a tensor with known shape that is incompatible with the output tensor's shape. However, this same error message could be also be reached if the input does not have `TensorStructInfo` at all, which would trigger a segfault. This commit updates the validation to print the argument's `StructInfo` directly, rather than a field from the struct info. This correctly raises an error for the cases where the argument is not a tensor, or is a tensor with unknown dimensionality, while still printing the explicit shape of the mismatched tensor when avalable. --- src/relax/op/op.cc | 80 ++++++----- tests/python/relax/test_transform.py | 197 ++++++++++++++++++++++----- 2 files changed, 202 insertions(+), 75 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 77cf4a2c6fd0..0a840248ffe8 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -419,13 +419,19 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { // may result in an error if performed before normalization. call = Downcast(NormalizeCallTIR(ctx, std::move(call))); + Array sinfo_outputs = [&]() -> Array { + auto out_sinfo = call->sinfo_args[0]; + if (auto* tuple_output = out_sinfo.as()) { + return tuple_output->fields; + } else { + return {out_sinfo}; + } + }(); + // there must be an inplace index for each output const auto* attrs = call->attrs.as(); - size_t num_outputs = 1U; - if (auto* tup_info = call->sinfo_args[0].as()) { - num_outputs = tup_info->fields.size(); - } - if (attrs->inplace_indices.size() != num_outputs) { + ICHECK(attrs); + if (attrs->inplace_indices.size() != sinfo_outputs.size()) { ctx->ReportFatal(Diagnostic::Error(call) << "There must be an in-place index specified for each output"); } @@ -459,45 +465,37 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { // input shape // TODO(@slyubomirsky): eventually we will want to handle cases where that is not true Tuple call_args = Downcast(call->args[1]); - if (attrs->inplace_indices.size() == 1) { - auto* out_sinfo = call->sinfo_args[0].as(); - if (!out_sinfo) { - ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor"); + + for (size_t i_output = 0; i_output < attrs->inplace_indices.size(); i_output++) { + auto i_input = attrs->inplace_indices[i_output].IntValue(); + if (i_input == -1) { + continue; } - auto* input_sinfo = GetStructInfoAs( - call_args->fields[attrs->inplace_indices[0].IntValue()]); - if (!input_sinfo || !input_sinfo->shape.defined() || - !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(), - ctx->GetAnalyzer())) { + + auto sinfo_output = sinfo_outputs[i_output]; + auto tinfo_output = sinfo_output.as(); + + if (!tinfo_output || !tinfo_output->shape.defined() || tinfo_output->IsUnknownDtype()) { ctx->ReportFatal(Diagnostic::Error(call) - << "The shape of output 0 must match input " - << attrs->inplace_indices[0].IntValue() << ", whereas we have " - << out_sinfo->shape.value() << " in output 0 versus " - << input_sinfo->shape.value() << " in input " - << attrs->inplace_indices[0].IntValue()); + << "The output struct info for an in-place mutation must be a tensor " + << "with a defined shape and dtype, " + << "but output " << i_output << " has struct info " << sinfo_output); } - } else { - auto out_sinfos = call->sinfo_args[0].as()->fields; - for (size_t i = 0; i < attrs->inplace_indices.size(); i++) { - if (attrs->inplace_indices[i].IntValue() == -1) { - continue; - } - auto* out_sinfo = out_sinfos[i].as(); - if (!out_sinfo) { - ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor"); - } - auto* input_sinfo = GetStructInfoAs( - call_args->fields[attrs->inplace_indices[i].IntValue()]); - if (!input_sinfo || !input_sinfo->shape.defined() || - !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(), - ctx->GetAnalyzer())) { - ctx->ReportFatal(Diagnostic::Error(call) - << "The shape of output " << i << " must match that of input " - << attrs->inplace_indices[i].IntValue() << ", whereas we have " - << out_sinfo->shape.value() << " in output " << i << " versus " - << input_sinfo->shape.value() << " in input " - << attrs->inplace_indices[i].IntValue()); - } + + auto sinfo_input = GetStructInfo(call_args->fields[i_input]); + auto tinfo_input = sinfo_input.as(); + + if (!tinfo_input || + (tinfo_output->IsUnknownDtype() || tinfo_output->dtype != tinfo_input->dtype) || + (!tinfo_input->shape.defined() || + !CanProveShapeEqual(tinfo_input->shape.value(), tinfo_output->shape.value(), + ctx->GetAnalyzer()))) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The input used for an in-place mutation must be " + << "a tensor with identical shape and dtype as the output. " + << "However, output " << i_output << " with struct info " << sinfo_output + << " is specified as an in-place mutation of input " << i_input + << " with struct info " << sinfo_input); } } diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index e7e8f94fc2ac..ee2df866fb35 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -20,7 +20,7 @@ from tvm import relax import tvm.script -from tvm.script import tir as T, relax as R +from tvm.script import ir as I, tir as T, relax as R def test_to_non_dataflow(): @@ -446,45 +446,174 @@ def foo( tvm.ir.assert_structural_equal(Expected["foo"], new_mod["foo"], map_free_vars=True) -@pytest.mark.xfail() def test_call_tir_inplace_repeated_input(): - @tvm.script.ir_module - class Input: - @T.prim_func - def func( - A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32") - ): - T.evaluate(0) + with pytest.raises(tvm.error.DiagnosticError): + + @tvm.script.ir_module + class Input: + @T.prim_func + def func( + A: T.Buffer((2, 3), "int32"), + B: T.Buffer((2, 3), "int32"), + C: T.Buffer((2, 3), "int32"), + ): + T.evaluate(0) - @R.function - def foo( - x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32") - ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")): - R.func_attr({"relax.force_pure": True}) - gv0 = R.call_tir_inplace( - Input.func, - (x, y, z), - # repeated 0 -> that's an error - [0, 0], - [R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")], - ) - return gv0 + @R.function + def foo( + x: R.Tensor((2, 3), "int32"), + y: R.Tensor((2, 3), "int32"), + z: R.Tensor((2, 3), "int32"), + ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")): + R.func_attr({"relax.force_pure": True}) + gv0 = R.call_tir_inplace( + Input.func, + (x, y, z), + # repeated 0 -> that's an error + [0, 0], + [R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")], + ) + return gv0 -@pytest.mark.xfail() def test_call_tir_inplace_all_new(): - @tvm.script.ir_module - class Input: - @T.prim_func - def func(A: T.Buffer((2, 3), "int32")): - T.evaluate(0) + with pytest.raises(tvm.error.DiagnosticError): - @R.function - def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): - R.func_attr({"relax.force_pure": True}) - # cannot make the only output a fresh one - gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32")) - return gv0 + @tvm.script.ir_module + class Input: + @T.prim_func + def func(A: T.Buffer((2, 3), "int32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + R.func_attr({"relax.force_pure": True}) + # cannot make the only output a fresh one + gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32")) + return gv0 + + +def test_inplace_mutation_with_tuple_argument_raises_error(): + """TIR PrimFuncs do not support Tuple arguments + + The `R.call_tir_inplace` operator must receive an in-line tuple of + arguments, where each argument in the tuple may be expressed in + TIR. Here, `[[A]]` specifies a tuple of arguments, where the + first argument is itself a tuple. Since PrimFuncs do not support + Tuple arguments, this is invalid. + + This is a regression test. In previous implementations, this + triggered a segfault rather than raising an exception. + + """ + with pytest.raises(tvm.error.DiagnosticError): + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): + cls = Module + gv1 = R.call_tir_inplace( + cls.multiply_by_two, + [[A]], + out_sinfo=R.Tensor((16,), dtype="float32"), + inplace_indices=[0], + ) + return gv1 + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer((16,), "float32")): + for i in range(16): + A[i] = A[i] * T.float32(2) + + +def test_inplace_mutation_with_non_tensor_argument_raises_error(): + """In-place argument must be a tensor + + The `R.call_tir_inplace` operator must receive an in-line tuple of + arguments, where each argument in the tuple may be expressed in + TIR. Here, the argument `A` is not a tensor. + + This is a regression test. In previous implementations, this + triggered a segfault rather than raising an exception. + + """ + with pytest.raises(tvm.error.DiagnosticError): + + @I.ir_module + class Module: + @R.function + def main(A: R.Object): + gv1 = R.call_tir_inplace( + Module.multiply_by_two, + [A], + out_sinfo=R.Tensor((16,), dtype="float32"), + inplace_indices=[0], + ) + return gv1 + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer((16,), "float32")): + for i in range(16): + A[i] = A[i] * T.float32(2) + + +def test_inplace_mutation_with_incompatible_tensor_shape_raises_error(): + """In-place argument must have compatible shape + + The `R.call_tir_inplace` operator must receive an in-line tuple of + arguments, where the shape of each in-place argument is compatible + with the corresponding output. Here, the shape of argument `A` is + different than the output's shape (`[32]` as opposed to `[16]`). + + """ + with pytest.raises(tvm.error.DiagnosticError): + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([32], dtype="float32")): + gv1 = R.call_tir_inplace( + Module.multiply_by_two, + [A], + out_sinfo=R.Tensor((16,), dtype="float32"), + inplace_indices=[0], + ) + return gv1 + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer((16,), "float32")): + for i in range(16): + A[i] = A[i] * T.float32(2) + + +def test_inplace_mutation_with_incompatible_tensor_dtype_raises_error(): + """In-place argument must have compatible dtype + + The `R.call_tir_inplace` operator must receive an in-line tuple of + arguments, where the shape of each in-place argument is compatible + with the corresponding output. Here, the dtype of argument `A` is + different than the output's dtype (`int32` as opposed to `float32`). + + """ + with pytest.raises(tvm.error.DiagnosticError): + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], dtype="int32")): + gv1 = R.call_tir_inplace( + Module.multiply_by_two, + [A], + out_sinfo=R.Tensor((16,), dtype="float32"), + inplace_indices=[0], + ) + return gv1 + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer((16,), "float32")): + for i in range(16): + A[i] = A[i] * T.float32(2) if __name__ == "__main__": From 05e2bc3340d1c0ca505e8a66bee29ffd5d294379 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 6 Aug 2024 07:13:49 -0700 Subject: [PATCH 050/202] [Relax] Implement R.ensure_zero_offset and update memory planning for R.view (#17145) Previously, `R.view` was legalized to extern call to `runtime.TVMArrayCreateView` during `LegalizeOps`. This call to extern func can't be properly handled by `StaticBlockPlanMemory` because it assumes the extern func does not retain the input buffer. Extern func returning a view of the input would break the ref count of the buffer. This PR defers the legalization of `R.view` so that it can be explicitly handled by memory planning. A new op `R.ensure_aligned` is added as discussed in #16955 --- include/tvm/relax/backend.h | 2 +- include/tvm/relax/op_attr_types.h | 9 +++ include/tvm/runtime/device_api.h | 5 ++ python/tvm/relax/op/memory/__init__.py | 2 +- python/tvm/relax/op/memory/view.py | 17 ++++++ python/tvm/relax/pipeline.py | 2 +- python/tvm/relax/transform/__init__.py | 9 +-- python/tvm/relax/transform/transform.py | 17 +++++- ...ltin_lower.cc => lower_runtime_builtin.cc} | 26 ++++++--- src/relax/op/memory/view.cc | 35 +++++++++++- src/relax/op/memory/view.h | 3 + .../transform/static_plan_block_memory.cc | 13 +++-- src/runtime/cpu_device_api.cc | 2 + src/runtime/cuda/cuda_device_api.cc | 2 + src/runtime/relax_vm/builtin.cc | 19 +++++++ tests/python/relax/test_op_view.py | 31 +++++----- ...test_transform_static_plan_block_memory.py | 57 ++++++++++++++++++- tests/python/relax/test_vm_builtin_lower.py | 4 +- 18 files changed, 211 insertions(+), 44 deletions(-) rename src/relax/backend/vm/{vm_builtin_lower.cc => lower_runtime_builtin.cc} (90%) diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h index 2fb11f5a6f83..e7d13c47b2bd 100644 --- a/include/tvm/relax/backend.h +++ b/include/tvm/relax/backend.h @@ -35,7 +35,7 @@ namespace transform { * * \return The Pass. */ -TVM_DLL Pass VMBuiltinLower(); +TVM_DLL Pass LowerRuntimeBuiltin(); /*! * \brief Lower the shape expression in relax to VM shape heap and TIR functions. diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index b44c4582d82d..291bee597c03 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -79,6 +79,15 @@ using FNormalize = runtime::TypedPackedFunc; +/*! \brief The function type of a function to lower the runtime builtin. + * + * A builtin function may be lowered to a lowered form in `LowerRuntimeBuiltin`. + * + * \param bb The BlockBuilder context. + * \param call The call to be lowered. + */ +using FLowerBuiltin = runtime::TypedPackedFunc; + /*! * \brief Gradient for a specific op. * diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 14b2b84b0d36..c33606d98ed3 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -240,6 +240,11 @@ class TVM_DLL DeviceAPI { return device_type != kDLCPU && device_type != kDLMicroDev; } + /*! + * \brief Whether pointer arithmetics on a device owned pointer may be performed on the host. + */ + virtual bool SupportsDevicePointerArithmeticsOnHost() { return false; } + protected: /*! * \brief copy data from one place to another diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py index 422c5d2e1f53..1191550085de 100644 --- a/python/tvm/relax/op/memory/__init__.py +++ b/python/tvm/relax/op/memory/__init__.py @@ -17,4 +17,4 @@ """Relax memory primitives.""" from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor -from .view import view +from .view import view, ensure_zero_offset diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py index 0c3d8a03b2dd..95adc782092f 100644 --- a/python/tvm/relax/op/memory/view.py +++ b/python/tvm/relax/op/memory/view.py @@ -92,3 +92,20 @@ def _normalize(expr, relax_cls): relative_byte_offset = _normalize(relative_byte_offset, PrimValue) return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore + + +def ensure_zero_offset(data: Expr) -> Expr: + """ + Ensure the tensor has elem_offset == 0. A copy will be made if necessary. + + Parameters + ---------- + data : relax.Expr + The input tensor + + Results + ------- + result : relax.Expr + The tensor with elem_offset == 0 + """ + return _ffi_api.ensure_zero_offset(data) # type: ignore diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index d068f800d0e9..38242ff4d2d3 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -92,7 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I transform.RewriteCUDAGraph(), transform.LowerAllocTensor(), transform.KillAfterLastUse(), - transform.VMBuiltinLower(), + transform.LowerRuntimeBuiltin(), transform.ComputePrimValue(), transform.VMShapeLower(), transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 5789e2fcf235..1ce864651cd9 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -55,6 +55,7 @@ LegalizeOps, LiftTransformParams, LowerAllocTensor, + LowerRuntimeBuiltin, MergeCompositeFunctions, MetaScheduleApplyDatabase, MetaScheduleTuneIRMod, @@ -64,8 +65,8 @@ PatternCheckContext, RealizeVDevice, RemovePurityChecking, - RemoveUnusedParameters, RemoveUnusedOutputs, + RemoveUnusedParameters, ReorderPermuteDimsAfterConcat, ReorderTakeAfterMatmul, RewriteCUDAGraph, @@ -84,14 +85,14 @@ function_pass, ) +from .attach_external_modules import AttachExternModules +from .fast_math import FastMathTransform +from .fuse_transpose_matmul import FuseTransposeMatmul from .ipc_allreduce_rewrite import IPCAllReduceRewrite from .lazy_transform_params import LazyTransformParams from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage from .optimize_layout_transform import OptimizeLayoutTransform from .remove_redundant_reshape import RemoveRedundantReshape -from .fast_math import FastMathTransform -from .fuse_transpose_matmul import FuseTransposeMatmul -from .attach_external_modules import AttachExternModules # Import to register the legalization functions. from . import legalize_ops, tuning_api diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 3528b4429e6f..2546284625e9 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -19,6 +19,7 @@ import functools import inspect import types +import warnings from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np # type: ignore @@ -586,6 +587,16 @@ def ComputePrimValue() -> tvm.ir.transform.Pass: return _ffi_api.ComputePrimValue() # type: ignore +def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass: + """Lowering generic intrinsic to VM intrinsics. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.LowerRuntimeBuiltin() # type: ignore + + def VMBuiltinLower() -> tvm.ir.transform.Pass: """Lowering generic intrinsic to VM intrinsics. @@ -593,7 +604,11 @@ def VMBuiltinLower() -> tvm.ir.transform.Pass: ------- ret: tvm.ir.transform.Pass """ - return _ffi_api.VMBuiltinLower() # type: ignore + warnings.warn( + "tvm.relax.transform.VMBuiltinLower has been renamed to 'LowerRuntimeBuiltin'. " + "This wrapper is for backwards compatibility, and will be removed in a later update." + ) + return _ffi_api.LowerRuntimeBuiltin() # type: ignore def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/lower_runtime_builtin.cc similarity index 90% rename from src/relax/backend/vm/vm_builtin_lower.cc rename to src/relax/backend/vm/lower_runtime_builtin.cc index 887998d004c7..a3867ae92448 100644 --- a/src/relax/backend/vm/vm_builtin_lower.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -17,13 +17,14 @@ * under the License. */ /*! - * \file src/relax/backend/vm/vm_builtin_lower.cc + * \file src/relax/backend/vm/lower_runtime_builtin.cc * \brief Lowers most builtin functions and packed calls. */ #include #include #include #include +#include #include #include #include @@ -33,11 +34,12 @@ namespace relax { // This pass lowers most ops to VM specific builtins. // TODO(relax-team): revisit after PrimValue. -class VMBuiltinLowerMutator : public ExprMutator { +class LowerRuntimeBuiltinMutator : public ExprMutator { public: using ExprMutator::VisitExpr_; Expr VisitExpr_(const CallNode* call_node) final { + static const auto& lower_builtin_fmap = Op::GetAttrMap("FLowerBuiltin"); // post-order mutation Call call = Downcast(VisitExprPostOrder_(call_node)); @@ -64,9 +66,13 @@ class VMBuiltinLowerMutator : public ExprMutator { return MakeMemAllocTensor(call); } else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) { return MakeMemKillObject(call); - } else { - return call; + } else if (const auto* op_node = call->op.as()) { + Op op = GetRef(op_node); + if (lower_builtin_fmap.count(op)) { + return lower_builtin_fmap[op](builder_, call); + } } + return call; } Expr MakeMemAllocStorage(const Call& call) { @@ -210,17 +216,19 @@ class VMBuiltinLowerMutator : public ExprMutator { const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; }; -Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); } +Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); } namespace transform { -Pass VMBuiltinLower() { +Pass LowerRuntimeBuiltin() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { return Downcast(VMBuiltinLower(f)); }; - return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {}); + [=](Function f, IRModule m, PassContext pc) { + return Downcast(LowerRuntimeBuiltin(f)); + }; + return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {}); } -TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower); +TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin); } // namespace transform } // namespace relax diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index e7634c7edfce..21a72f6200b0 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -291,7 +291,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView); -Expr LegalizeView(const BlockBuilder& bb, const Call& call) { +Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { Expr data = call->args[0]; Expr shape = call->args[1]; Expr dtype = call->args[2]; @@ -352,8 +352,37 @@ TVM_REGISTER_OP("relax.memory.view") "The view's byte offset, relative to the input tensor's byte offset.") .set_attr("RequiresArgumentShapes", Bool(false)) .set_attr("FInferStructInfo", InferStructInfoView) - .set_attr("FLegalize", LegalizeView) - .set_attr("FPurity", Bool(true)); + .set_attr("FPurity", Bool(true)) + .set_attr("FLowerBuiltin", LowerBuiltinView); + +Expr ensure_zero_offset(const Expr& x) { + static const Op& op = Op::Get("relax.memory.ensure_zero_offset"); + return Call(op, {x}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset); + +StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Operator " << call->op << " should receive 1 argument, " + << "but received " << call->args); + } + return GetStructInfo(call->args[0]); +} + +Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) { + const ExternFunc builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"}; + return Call(builtin_ensure_zero_offset_, call->args, Attrs(), {GetStructInfo(call)}); +} + +TVM_REGISTER_OP("relax.memory.ensure_zero_offset") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("RequiresArgumentShapes", Bool(false)) + .set_attr("FInferStructInfo", InferStructInfoEnsureZeroOffset) + .set_attr("FPurity", Bool(true)) + .set_attr("FLowerBuiltin", LowerBuiltinEnsureZeroOffset); } // namespace relax } // namespace tvm diff --git a/src/relax/op/memory/view.h b/src/relax/op/memory/view.h index bc8002fa5b69..77ec7e9833cc 100644 --- a/src/relax/op/memory/view.h +++ b/src/relax/op/memory/view.h @@ -32,6 +32,9 @@ namespace relax { /*! \brief View a tensor with different properties. */ Expr view(Expr x, Optional shape, Optional dtype, Optional relative_byte_offset); +/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */ +Expr ensure_aligned(const Expr& x); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 2b16d8650906..74200526b699 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -286,8 +286,13 @@ class TokenAllocator1D { std::vector full_pool_; }; -/*! \brief Check if the input op is "relax.reshape". */ -bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); } +/*! \brief Check if the input op is a memory op that may return the same buffer. */ +bool IsInplaceMemoryOp(const Expr& op) { + static const Op& reshape_op = Op::Get("relax.reshape"); + static const Op& view_op = Op::Get("relax.memory.view"); + static const Op& ensure_zero_offset_op = Op::Get("relax.memory.ensure_zero_offset"); + return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_zero_offset_op); +} /*! \brief The base class for the storage allocation visitor. */ class StorageAllocatorBaseVisitor : public ExprVisitor { @@ -498,7 +503,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { // Create a storage token for builtin alloc_tensor. this->CreateToken(call); return; - } else if (IsReshape(call->op)) { + } else if (IsInplaceMemoryOp(call->op)) { // Reuse the input's token for builtin reshape. SetTokens(call, GetTokens(call->args[0])); return; @@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { block_tokens.push_back(new_token.get()); } return; - } else if (IsReshape(call->op)) { + } else if (IsInplaceMemoryOp(call->op)) { Tokens tokens = GetTokens(call->args[0]); ICHECK(!tokens.IsNested()); if (tokens.IsLeaf()) { diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 774335f5660b..ccd726a6ece6 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -73,6 +73,8 @@ class CPUDeviceAPI final : public DeviceAPI { void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + bool SupportsDevicePointerArithmeticsOnHost() final { return true; } + static CPUDeviceAPI* Global() { // NOTE: explicitly use new to avoid exit-time destruction of global state // Global state will be recycled by OS as the process exits. diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 66357a191541..33908d750d6d 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -262,6 +262,8 @@ class CUDADeviceAPI final : public DeviceAPI { CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data); } + bool SupportsDevicePointerArithmeticsOnHost() final { return true; } + static CUDADeviceAPI* Global() { // NOTE: explicitly use new to avoid exit-time destruction of global state // Global state will be recycled by OS as the process exits. diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index af1cf9d20335..9fe6fba80f5c 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -551,6 +551,25 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data return ShapeTuple(out_shape); }); +TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) { + if (data->byte_offset == 0) { + return data; + } + auto* device_api = DeviceAPI::Get(data->device); + if (device_api->SupportsDevicePointerArithmeticsOnHost() && + data->byte_offset % tvm::runtime::kAllocAlignment == 0) { + DLManagedTensor* dl_tensor = data.ToDLPack(); + dl_tensor->dl_tensor.data = + reinterpret_cast(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset; + dl_tensor->dl_tensor.byte_offset = 0; + return NDArray::FromDLPack(dl_tensor); + } else { + auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device); + new_array.CopyFrom(data); + return new_array; + } +}); + } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py index 2433821c2abd..0900e1be306b 100644 --- a/tests/python/relax/test_op_view.py +++ b/tests/python/relax/test_op_view.py @@ -452,7 +452,9 @@ def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")): tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) -def test_legalize_without_any_changes_is_no_op(): +def test_legalize_is_no_op(): + """R.memory.view is not legalized until LowerRuntimeBuiltin""" + @I.ir_module class Before: @R.function @@ -460,18 +462,13 @@ def main(A: R.Tensor([4096], "float32")): B = R.memory.view(A) return B - @I.ir_module - class Expected: - @R.function - def main(A: R.Tensor([4096], "float32")): - B = A - return B + Expected = Before After = tvm.relax.transform.LegalizeOps()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_shape_change(): +def test_lower_runtime_builtin_shape_change(): @I.ir_module class Before: @R.function @@ -497,11 +494,11 @@ def main(A: R.Tensor([4096], "float32")): ) return B - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_view_shape_from_unknown(): +def test_lower_runtime_builtin_view_shape_from_unknown(): """R.memory.view does not require the input tensor to have a known shape""" @I.ir_module @@ -529,11 +526,11 @@ def main(A: R.Tensor(dtype="float32")): ) return B - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_dtype_change(): +def test_lower_runtime_builtin_dtype_change(): @I.ir_module class Before: @R.function @@ -559,11 +556,11 @@ def main(A: R.Tensor([4096], "float32")): ) return B - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_byte_offset(): +def test_lower_runtime_builtin_byte_offset(): @I.ir_module class Before: @R.function @@ -589,11 +586,11 @@ def main(A: R.Tensor([4096], "float32")): ) return B - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_view_with_multiple_updated_fields(): +def test_lower_runtime_builtin_view_with_multiple_updated_fields(): """R.memory.view may update more than one field in the view In this test case, a 4-kilobyte buffer is provided. The first @@ -650,7 +647,7 @@ def main(A: R.Tensor([4096], "uint8")): ) return (B, C) - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 63f422d4cfbe..f9e632d34897 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -185,7 +185,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32 tvm.ir.assert_structural_equal(mod, Expected) mod = relax.transform.LowerAllocTensor()(mod) mod = relax.transform.KillAfterLastUse()(mod) - mod = relax.transform.VMBuiltinLower()(mod) + mod = relax.transform.LowerRuntimeBuiltin()(mod) tvm.ir.assert_structural_equal(mod, ExpectedLowered) @@ -1449,5 +1449,60 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +def test_view(): + @I.ir_module + class Before: + @T.prim_func + def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): + T.evaluate(0) + + @R.function + def main(): + cls = Before + x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0) + x1 = R.memory.view(x, [128], "float32", 0) + x2 = R.memory.ensure_zero_offset(x1) + y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0) + cls.tir_exp(x2, y) + z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0) + cls.tir_exp(y, z) + return z + + @I.ir_module + class Expected: + @T.prim_func + def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): + T.evaluate(0) + + @R.function + def main() -> R.Tensor((128,), dtype="float32"): + cls = Expected + storage: R.Object = R.memory.alloc_storage( + R.shape([1024]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor( + storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32") + ) + x1: R.Tensor((128,), dtype="float32") = R.memory.view( + x, R.shape([128]), R.dtype("float32"), R.prim_value(0) + ) + x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_zero_offset(x1) + storage1: R.Object = R.memory.alloc_storage( + R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor( + storage1, R.prim_value(0), R.shape([128]), R.dtype("float32") + ) + cls.tir_exp(x2, y) + z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([128]), R.dtype("float32"), R.prim_value(0), R.str("global") + ) + cls.tir_exp(y, z) + return z + + after = relax.transform.StaticPlanBlockMemory()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py index df28db4d46d2..984f9f958ca2 100644 --- a/tests/python/relax/test_vm_builtin_lower.py +++ b/tests/python/relax/test_vm_builtin_lower.py @@ -57,7 +57,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: gv0 = alloc return gv0 - After = relax.transform.VMBuiltinLower()(Before) + After = relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) @@ -79,7 +79,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: return gv0 with pytest.raises(tvm.TVMError): - relax.transform.VMBuiltinLower()(Before) + relax.transform.LowerRuntimeBuiltin()(Before) if __name__ == "__main__": From 11be83262024fa73a36b744cfd2fc334d5b5e49d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 7 Aug 2024 12:19:13 -0400 Subject: [PATCH 051/202] Revert "[FFI][RUNTIME] Introduce runtime boxed types for int/float/bool" (#17252) Revert "[FFI][RUNTIME] Introduce runtime boxed types for int/float/bool (#16183)" This reverts commit 5f22be4d83ca698e316ac342f32f5b4d38155ca8. --- include/tvm/ir/attrs.h | 76 +- include/tvm/ir/expr.h | 130 +--- include/tvm/ir/transform.h | 34 +- include/tvm/meta_schedule/schedule_rule.h | 8 +- include/tvm/relay/attrs/transform.h | 2 +- include/tvm/runtime/c_runtime_api.h | 5 +- .../tvm/runtime/container/boxed_primitive.h | 143 ---- include/tvm/runtime/container/variant.h | 2 +- include/tvm/runtime/ndarray.h | 2 - include/tvm/runtime/packed_func.h | 689 ++++-------------- include/tvm/target/target.h | 10 +- include/tvm/target/target_kind.h | 4 +- include/tvm/tir/expr.h | 57 -- include/tvm/tir/function.h | 2 +- include/tvm/tir/schedule/schedule.h | 5 +- python/tvm/_ffi/_ctypes/object.py | 22 - python/tvm/_ffi/_ctypes/packed_func.py | 7 +- python/tvm/_ffi/_ctypes/types.py | 3 - python/tvm/_ffi/_cython/base.pxi | 5 +- python/tvm/_ffi/_cython/object.pxi | 10 - python/tvm/_ffi/_cython/packed_func.pxi | 9 +- python/tvm/_ffi/runtime_ctypes.py | 3 +- python/tvm/driver/tvmc/registry.py | 22 +- python/tvm/ir/attrs.py | 2 +- python/tvm/ir/expr.py | 5 +- python/tvm/meta_schedule/tune_context.py | 3 +- python/tvm/relax/op/statistical.py | 22 +- python/tvm/relax/testing/ast_printer.py | 18 +- python/tvm/relax/training/setup_trainer.py | 4 +- python/tvm/relax/utils.py | 3 - .../relay/backend/contrib/ethosu/legalize.py | 2 +- python/tvm/relay/op/_tensor_grad.py | 3 - python/tvm/relay/op/_transform.py | 8 +- python/tvm/relay/op/contrib/ethosu.py | 4 +- python/tvm/relay/op/transform.py | 25 +- .../transform/fake_quantization_to_integer.py | 5 +- python/tvm/runtime/__init__.py | 4 +- python/tvm/runtime/container.py | 38 - python/tvm/runtime/object_generic.py | 75 +- python/tvm/script/parser/tir/parser.py | 2 - python/tvm/te/hybrid/calls.py | 12 +- python/tvm/te/hybrid/parser.py | 4 +- python/tvm/te/hybrid/utils.py | 28 +- python/tvm/te/operation.py | 1 + python/tvm/te/tensor.py | 11 +- python/tvm/tir/__init__.py | 1 - python/tvm/tir/expr.py | 4 - python/tvm/tir/ir_builder.py | 6 +- python/tvm/tir/op.py | 151 ++-- python/tvm/tir/schedule/trace.py | 15 +- python/tvm/topi/arm_cpu/conv2d_gemm.py | 2 +- python/tvm/topi/cuda/batch_matmul.py | 8 +- rust/tvm-rt/src/module.rs | 5 +- rust/tvm-sys/src/packed_func.rs | 35 +- src/auto_scheduler/compute_dag.cc | 16 +- .../search_policy/sketch_policy_rules.cc | 3 +- src/auto_scheduler/search_policy/utils.h | 12 +- .../msc/core/printer/msc_base_printer.cc | 9 - .../msc/core/printer/prototxt_printer.cc | 4 - src/contrib/msc/core/utils.cc | 4 - src/driver/driver_api.cc | 5 +- src/ir/attrs.cc | 89 --- src/ir/expr.cc | 17 +- src/ir/transform.cc | 41 +- src/meta_schedule/database/database_utils.cc | 10 +- src/meta_schedule/database/json_database.cc | 4 +- .../mutator/mutate_thread_binding.cc | 2 +- src/meta_schedule/mutator/mutate_tile_size.cc | 6 +- src/meta_schedule/mutator/mutate_unroll.cc | 6 +- .../schedule/cuda/thread_bind.cc | 6 +- .../schedule_rule/cross_thread_reduction.cc | 8 +- .../schedule_rule/multi_level_tiling.cc | 5 +- .../parallel_vectorize_unroll.cc | 6 +- .../schedule_rule/schedule_rule.cc | 12 +- src/meta_schedule/utils.h | 38 +- src/node/boxed_primitive.cc | 134 ---- src/node/script_printer.cc | 16 +- src/node/structural_equal.cc | 37 +- src/relax/backend/vm/codegen_vm.cc | 2 - src/relax/backend/vm/codegen_vm_tir.cc | 30 +- src/relax/op/tensor/create.cc | 2 +- src/relax/op/tensor/create.h | 2 +- src/relax/op/tensor/manipulate.cc | 6 +- src/relax/op/tensor/manipulate.h | 4 +- .../backend/contrib/cmsisnn/compiler_attrs.cc | 2 +- src/relay/backend/contrib/cmsisnn/target.cc | 2 +- src/relay/backend/contrib/cutlass/target.cc | 18 +- .../backend/contrib/ethosn/ethosn_api.cc | 6 +- src/relay/backend/contrib/ethosu/codegen.cc | 3 +- .../backend/contrib/ethosu/preprocess.cc | 4 +- .../contrib/example_target_hooks/target.cc | 2 +- src/relay/backend/contrib/tensorrt/codegen.cc | 4 +- src/relay/backend/contrib/tensorrt/target.cc | 14 +- src/relay/backend/contrib/uma/targets.cc | 7 +- src/relay/backend/executor.cc | 10 +- src/relay/backend/runtime.cc | 4 +- src/relay/ir/dataflow_matcher.cc | 36 - src/relay/op/make_op.h | 2 +- src/relay/op/tensor/transform.cc | 48 +- .../transforms/combine_parallel_op_batch.cc | 2 +- src/relay/transforms/fold_constant.cc | 2 +- src/relay/transforms/higher_order_gradient.cc | 2 + src/relay/transforms/to_mixed_precision.cc | 4 +- src/runtime/boxed_primitive.cc | 65 -- src/runtime/crt/common/crt_runtime_api.c | 8 +- src/runtime/disco/bcast_session.cc | 8 +- src/runtime/minrpc/rpc_reference.h | 8 - src/runtime/relax_vm/builtin.cc | 10 +- .../printer/doc_printer/python_doc_printer.cc | 23 +- src/script/printer/ir/misc.cc | 15 - src/script/printer/relax/tir.cc | 6 +- src/support/array.h | 52 +- src/support/ffi_testing.cc | 52 -- src/target/llvm/codegen_cpu.cc | 29 +- src/target/llvm/llvm_instance.cc | 14 +- src/target/tag.cc | 66 +- src/target/target.cc | 66 +- src/target/target_kind.cc | 137 ++-- src/te/operation/compute_op.cc | 26 +- src/te/operation/create_primfunc.cc | 15 +- src/te/operation/placeholder_op.cc | 12 +- src/te/schedule/schedule_dataflow_rewrite.cc | 7 +- .../analysis/calculate_allocated_memory.cc | 2 +- src/tir/ir/expr.cc | 20 +- src/tir/ir/function.cc | 7 - src/tir/ir/specialize.cc | 2 +- src/tir/ir/stmt.cc | 32 +- src/tir/ir/utils.cc | 68 -- src/tir/ir/utils.h | 51 -- src/tir/op/op.cc | 16 +- src/tir/schedule/concrete_schedule.cc | 14 +- src/tir/schedule/concrete_schedule.h | 5 +- src/tir/schedule/instruction_traits.h | 5 - src/tir/schedule/primitive.h | 5 +- src/tir/schedule/primitive/annotate.cc | 3 - src/tir/schedule/primitive/sampling.cc | 36 +- src/tir/schedule/trace.cc | 12 +- src/tir/schedule/traced_schedule.cc | 6 +- src/tir/schedule/traced_schedule.h | 5 +- .../transforms/inline_private_functions.cc | 2 +- src/tir/transforms/ir_utils.h | 1 - src/tir/transforms/lower_tvm_builtin.cc | 2 - src/tir/transforms/make_packed_api.cc | 45 +- tests/cpp/relay/backend/runtime_test.cc | 10 +- tests/cpp/target_test.cc | 56 +- .../test_runtime_packed_func.py | 18 +- .../arith/test_arith_canonical_simplify.py | 23 +- .../arith/test_arith_iter_affine_map.py | 35 +- .../test_arith_narrow_predicate_expression.py | 21 +- .../arith/test_arith_rewrite_simplify.py | 63 +- .../test_arith_solve_linear_equations.py | 15 +- .../test_arith_solve_linear_inequality.py | 11 +- .../codegen/test_target_codegen_cuda.py | 2 +- .../codegen/test_target_codegen_llvm.py | 41 -- .../ir/test_container_structural_equal.py | 30 +- tests/python/ir/test_ir_container.py | 15 +- tests/python/ir/test_ir_type.py | 9 +- .../test_distributed_tvmscript_printer.py | 4 +- tests/python/relax/test_ast_printer.py | 2 +- .../relax/test_backend_dispatch_sort_scan.py | 10 +- .../relax/test_tvmscript_printer_relax.py | 6 +- tests/python/relax/test_vm_build.py | 2 +- tests/python/relax/test_vm_codegen_tir.py | 5 +- tests/python/relay/test_dataflow_pattern.py | 3 +- tests/python/relay/test_executor.py | 2 +- tests/python/relay/test_runtime.py | 4 +- tests/python/relay/test_type_infer.py | 65 +- .../python/runtime/test_runtime_container.py | 130 +--- tests/python/te/test_te_schedule_tensorize.py | 20 +- tests/python/te/test_te_tag.py | 10 +- tests/python/tir-base/test_lower_build.py | 2 +- tests/python/tir-base/test_tir_buffer.py | 17 +- tests/python/tir-base/test_tir_index_map.py | 55 +- tests/python/tir-base/test_tir_nodes.py | 27 +- .../test_tir_schedule_sampling.py | 2 +- .../tir-schedule/test_tir_schedule_state.py | 4 +- ...est_tir_transform_compact_buffer_region.py | 71 +- ...tir_transform_instrument_bound_checkers.py | 8 +- .../test_tir_transform_make_packed_api.py | 139 ---- .../test_tir_transform_storage_rewrite.py | 4 +- .../tvmscript/test_tvmscript_error_report.py | 17 +- .../tvmscript/test_tvmscript_printer_tir.py | 12 +- .../tvmscript/test_tvmscript_roundtrip.py | 31 +- vta/python/vta/transform.py | 13 +- 184 files changed, 1221 insertions(+), 3215 deletions(-) delete mode 100644 include/tvm/runtime/container/boxed_primitive.h delete mode 100644 src/node/boxed_primitive.cc delete mode 100644 src/runtime/boxed_primitive.cc delete mode 100644 src/tir/ir/utils.cc delete mode 100644 src/tir/ir/utils.h diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index d038d5f59a5f..81611b1a535a 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -265,16 +265,7 @@ class DictAttrs : public Attrs { auto it = node->dict.find(attr_key); if (it != node->dict.end()) { - // For backwards compatibility, return through TVMRetValue. - // This triggers any automatic conversions registered with - // PackedFuncValueConverter. Importantly, this allows use of - // `GetAttr` and `GetAttr` for properties that - // are stored internally as `runtime::Box` and - // `runtime::Box`. - TVMRetValue ret; - ret = (*it).second; - Optional obj = ret; - return obj; + return Downcast>((*it).second); } else { return default_value; } @@ -324,46 +315,6 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } -/*! - * \brief Copy the DictAttrs, but overrides attributes with the - * entries from \p attrs. - * - * \param attrs The DictAttrs to update - * - * \param new_attrs Key/values attributes to add to \p attrs. - * - * \returns The new DictAttrs with updated attributes. - */ -DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); - -/*! - * \brief Copy the DictAttrs, but overrides a single attribute. - * - * \param attrs The DictAttrs to update - * - * \param key The update to insert or update. - * - * \param value The new value of the attribute - * - * \returns The new DictAttrs with updated attributes. - */ -DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value); - -inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) { - return WithAttr(std::move(attrs), String(key), std::move(value)); -} - -/*! - * \brief Copy the DictAttrs, but without a specific attribute. - * - * \param attrs The DictAttrs to update - * - * \param key The key to remove - * - * \returns The new DictAttrs with updated attributes. - */ -DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key); - /*! * \brief Copy the function or module, but overrides * the attribute value key with the value. @@ -396,8 +347,12 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value); - + if (node->attrs.defined()) { + node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); + } else { + Map dict = {{attr_key, attr_value}}; + node->attrs = DictAttrs(dict); + } return input; } @@ -416,9 +371,13 @@ inline TFunc WithAttrs(TFunc input, Map attrs) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - - node->attrs = WithAttrs(std::move(node->attrs), attrs); - + if (node->attrs.defined()) { + for (const auto& pair : attrs) { + node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); + } + } else { + node->attrs = DictAttrs(std::move(attrs)); + } return input; } @@ -453,9 +412,10 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - TNode* node = input.CopyOnWrite(); - node->attrs = WithoutAttr(std::move(node->attrs), attr_key); - + if (input->attrs.defined()) { + TNode* node = input.CopyOnWrite(); + node->attrs.CopyOnWrite()->dict.erase(attr_key); + } return input; } diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index efde52385177..9b522389227a 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -770,121 +770,53 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { - -// Automatic conversion into IntImm, Integer, and Bool, when called -// through the FFI. Automatic conversions into PrimExpr are -// registered in "tvm/tir/expr.h", as it includes conversions to the -// TIR-only StringImm. -// -// While the FFI only requires the From() method, these -// implementations also define a TryFrom() method to avoid duplicate -// logic in the PrimExpr conversion. - +// common rule for RetValue and ArgValue template <> -struct PackedFuncValueConverter { - template - static Optional TryFrom(const PODSubclass& val) { - if (auto opt = val.TryAsInt()) { - int64_t value = opt.value(); - auto dtype = - (value > std::numeric_limits::max() || value < std::numeric_limits::min()) - ? DataType::Int(64) - : DataType::Int(32); - return IntImm(dtype, value); - } else if (auto opt = val.TryAsBool()) { - return IntImm(DataType::Int(32), opt.value()); - } else { - return NullOpt; +struct PackedFuncValueConverter { + static PrimExpr From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return PrimExpr(ObjectPtr(nullptr)); } - } - - template - static tvm::IntImm From(const PODSubclass& val) { - if (auto opt = TryFrom(val)) { - return opt.value(); - } else { - return val.template AsObjectRef(); + if (val.type_code() == kDLInt) { + int64_t value = val.operator int64_t(); + if (value > std::numeric_limits::max() || value < std::numeric_limits::min()) { + return IntImm(runtime::DataType::Int(64), value); + } + return IntImm(runtime::DataType::Int(32), val.operator int()); } - } -}; - -template <> -struct PackedFuncValueConverter { - template - static tvm::Integer From(const PODSubclass& val) { - if (auto opt = PackedFuncValueConverter::TryFrom(val)) { - return Integer(opt.value()); - } else { - return val.template AsObjectRef(); + if (val.type_code() == kDLFloat) { + return FloatImm(runtime::DataType::Float(32), val.operator double()); } - } -}; -template <> -struct PackedFuncValueConverter { - template - static Optional TryFrom(const PODSubclass& val) { - if (auto opt = val.TryAsBool()) { - return tvm::Bool(opt.value()); - } else if (auto opt = val.TryAsInt()) { - int value = opt.value(); - ICHECK(value == 0 || value == 1) - << "ValueError: boolean value can only be 0 or 1, but get " << value; - return tvm::Bool(static_cast(value)); - } else { - return NullOpt; - } - } - - template - static tvm::Bool From(const PODSubclass& val) { - if (auto opt = TryFrom(val)) { - return opt.value(); - } else { - return val.template AsObjectRef(); - } + return PrimExpr::FromObject_(val.AsObjectRef()); } }; template <> -struct PackedFuncValueConverter { - static Optional TryFrom(const TVMPODValue_& val) { - if (auto opt = val.TryAsFloat()) { - return FloatImm(runtime::DataType::Float(32), opt.value()); - } else { - return NullOpt; +struct PackedFuncValueConverter { + static tvm::Integer From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return Integer(ObjectPtr(nullptr)); } - } - - template - static tvm::FloatImm From(const PODSubclass& val) { - if (auto opt = TryFrom(val)) { - return opt.value(); - } else { - return val.template AsObjectRef(); + if (val.type_code() == kTVMArgInt) { + return Integer(val.operator int()); } + return val.AsObjectRef(); } }; -/* \brief Backwards compatibility wrapper for IntImm arguments - * - * In previous versions of TVM, IntImm was the default FFI type for - * integer arguments, instead of runtime::Int. For backwards - * compatibility where the callee has been updated to expected a - * runtime::Int, the caller has not been updated to provide a - * runtime::Int (e.g. relay script parsing), and the auto-unboxing of - * runtime::Int does not apply (e.g. making an `Array`), - * allow the IntImm to be generated. - */ template <> -struct PackedFuncValueConverter { - template - static runtime::Int From(const PODSubclass& val) { - if (val.template IsObjectRef()) { - return runtime::Int(val.template AsObjectRef()->value); - } else { - return val.template AsObjectRef(); +struct PackedFuncValueConverter { + static tvm::Bool From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return Bool(ObjectPtr(nullptr)); + } + if (val.type_code() == kTVMArgInt) { + int v = val.operator int(); + ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; + return Bool(static_cast(v)); } + return val.AsObjectRef(); } }; diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 5828d98206ad..adf332525020 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -271,36 +271,7 @@ class PassContext : public ObjectRef { using ValueNodeType = typename ValueType::ContainerType; // NOTE: we could further update the function later. uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); - auto type_key = runtime::Object::TypeIndex2Key(tindex); - - auto* reflection = ReflectionVTable::Global(); - - auto legalization = [=](ObjectRef obj) -> ObjectRef { - if (obj->IsInstance::ContainerType>()) { - return reflection->CreateObject(type_key, Downcast>(obj)); - } else { - // Backwards compatibility for config options defined prior to - // https://github.com/apache/tvm/pull/16183. This commit - // changed the default FFI conversion of python integers from - // `tvm::IntImm` to `runtime::Int`. - // - // This backwards compatibility fix can be removed when all - // options registered with TVM_REGISTER_PASS_CONFIG_OPTION are - // updated to use `runtime::Int` and `runtime::Bool`. - TVMRetValue ret; - ret = obj; - try { - ValueType legalized = ret; - return legalized; - } catch (Error& err) { - LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key - << ", but received error when converting to this type.\n" - << err.what(); - } - } - }; - - RegisterConfigOption(key, tindex, legalization); + RegisterConfigOption(key, tindex); return tindex; } @@ -314,8 +285,7 @@ class PassContext : public ObjectRef { // The exit of a pass context scope. TVM_DLL void ExitWithScope(); // Register configuration key value type. - TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index, - std::function legalization); + TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index); // Classes to get the Python `with` like syntax. friend class Internal; diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 90aec05187eb..d91812fb55cb 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef { * \param thread_extents Candidates of thread axis extent (values are required to be positive). * \return The schedule rule created */ - TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); + TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The schedule rule created @@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef { * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma. * \return The schedule rule created */ - TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // - int max_vectorize_extent, // - Array unroll_max_steps, // + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // + Array unroll_max_steps, // bool unroll_explicit); /*! * \brief Auto bind loops around the block to BlockIdx and ThreadIdx diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 91020fc7443b..249b9cd0e50d 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode { }; // struct SqueezeAttrs struct SplitAttrs : public tvm::AttrsNode { - Variant> indices_or_sections; + ObjectRef indices_or_sections; int axis; TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index b4c653a0a59e..f1046ef24266 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -81,7 +81,6 @@ #ifdef __cplusplus extern "C" { #endif -#include #include #include @@ -187,12 +186,11 @@ typedef enum { kTVMBytes = 12U, kTVMNDArrayHandle = 13U, kTVMObjectRValueRefArg = 14U, - kTVMArgBool = 15U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. // Open an issue at the repo if you need a section of code. - kTVMExtBegin = 16U, + kTVMExtBegin = 15U, kTVMNNVMFirst = 16U, kTVMNNVMLast = 20U, // The following section of code is used for non-reserved types. @@ -209,7 +207,6 @@ typedef DLTensor* TVMArrayHandle; */ typedef union { int64_t v_int64; - bool v_bool; double v_float64; void* v_handle; const char* v_str; diff --git a/include/tvm/runtime/container/boxed_primitive.h b/include/tvm/runtime/container/boxed_primitive.h deleted file mode 100644 index 8d01b5dc17b5..000000000000 --- a/include/tvm/runtime/container/boxed_primitive.h +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/boxed_primitive.h - * \brief Runtime container types for primitives stored as ObjectRef. - */ -#ifndef TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ -#define TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ - -#include -#include - -namespace tvm { -namespace runtime { - -namespace detail { -/* \brief Provide the BoxNode type traits in templated contexts - * - * The Box class is used in many templated contexts, and is easier - * to have templated over the primitive type. - * - * However, much of the TVM type system depends on classes having a - * unique name. For example, the use of `Object::IsInstance` depends - * on `Object::GetOrAllocRuntimeTypeIndex`. Any duplicate names will - * result in duplicate indices, and invalid downcasting. Furthermore, - * the name must be specified in the Python FFI using - * `tvm._ffi.register_object`. This prevents use of - * `typeid(T)::name()` to build a unique name, as the name is not - * required to be human-readable or consistent across compilers. - * - * This utility struct should be specialized over the primitive type - * held by the box, to allow explicit listing of the `_type_key` and - * other similar tratis. - * - * Note: This should only contain traits that are required at runtime, - * and should *not* contain extensions for features that are only - * available at compile-time. For integration with compile-time-only - * functionality (e.g. StructuralHash, StructuralEqual), see - * `BoxNodeCompileTimeTraits` in `src/node/boxed_primitive.cc`. - */ -template -struct BoxNodeRuntimeTraits; - -} // namespace detail - -template -class BoxNode : public Object { - public: - /*! \brief Constructor - * - * \param value The value to be boxed - */ - explicit BoxNode(Prim value) : value(value) {} - - /*! \brief The boxed value */ - Prim value; - - static constexpr const char* _type_key = detail::BoxNodeRuntimeTraits::_type_key; - static constexpr bool _type_has_method_visit_attrs = false; - TVM_DECLARE_FINAL_OBJECT_INFO(BoxNode, Object); -}; - -template -class Box : public ObjectRef { - public: - /*! \brief Constructor - * - * \param value The value to be boxed - */ - Box(Prim value) : ObjectRef(make_object>(value)) {} // NOLINT(*) - - operator Prim() const { return (*this)->value; } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Box, ObjectRef, BoxNode); -}; - -/*! \brief Boxed version of C++ int64_t - * - * Can be used to store POD integer values as a TVM ObjectRef. Used - * for FFI handling, and for storing POD types inside TVM containers. - */ -using Int = Box; - -/*! \brief Boxed version of C++ double - * - * Can be used to store POD floating-point values as a TVM ObjectRef. - * Used for FFI handling, and for storing POD types inside TVM - * containers. - */ -using Float = Box; - -/*! \brief Boxed version of C++ bool - * - * Can be used to store POD boolean values as a TVM ObjectRef. Used - * for FFI handling, and for storing POD types inside TVM containers. - * - * When passing from Python to C++, TVM PackedFunc conversion follow - * C++ conversion rules, and allow bool->int and int->bool - * conversions. When passing from C++ to Python, the types are - * returned as bool or int. If the C++ function uses ObjectRef to - * hold the object, a Python to C++ to Python round trip will preserve - * the distinction between bool and int. - */ -using Bool = Box; - -namespace detail { -template <> -struct BoxNodeRuntimeTraits { - static constexpr const char* _type_key = "runtime.BoxInt"; -}; - -template <> -struct BoxNodeRuntimeTraits { - static constexpr const char* _type_key = "runtime.BoxFloat"; -}; - -template <> -struct BoxNodeRuntimeTraits { - static constexpr const char* _type_key = "runtime.BoxBool"; -}; -} // namespace detail - -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h index e8defa4e6fee..7953ac47c1cf 100644 --- a/include/tvm/runtime/container/variant.h +++ b/include/tvm/runtime/container/variant.h @@ -82,7 +82,7 @@ class Variant : public ObjectRef { public: /* \brief Helper utility to check if the type is part of the variant */ template - static constexpr bool is_variant = (std::is_base_of_v || ...); + static constexpr bool is_variant = (std::is_same_v || ...); /* \brief Helper utility for SFINAE if the type is part of the variant */ template diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index fef61a753103..3eb225fccffe 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -226,8 +226,6 @@ class NDArray : public ObjectRef { protected: friend class TVMPODValue_; - template - friend class TVMPODValue_CRTP_; friend class TVMRetValue; friend class TVMArgsSetter; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 98196c13af7f..7266f8c4a50a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,7 +26,6 @@ #include #include -#include #include #include #include @@ -38,7 +37,6 @@ #include #include #include -#include #include #include #include @@ -431,11 +429,9 @@ inline const char* ArgTypeCode2Str(int type_code); inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) -#define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) \ - "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) - // macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) ICHECK_EQ(CODE, T) << TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) +#define TVM_CHECK_TYPE_CODE(CODE, T) \ + ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -514,7 +510,6 @@ struct ObjectTypeChecker> { } static std::string TypeName() { return "Array[" + ObjectTypeChecker::TypeName() + "]"; } }; - template struct ObjectTypeChecker> { static Optional CheckAndGetMismatch(const Object* ptr) { @@ -550,43 +545,40 @@ struct ObjectTypeChecker> { } }; -template -struct ObjectTypeChecker> { - static Optional CheckAndGetMismatch(const Object* ptr) { - return ObjectTypeChecker::CheckAndGetMismatch(ptr); - } - static bool Check(const Object* ptr) { return ObjectTypeChecker::Check(ptr); } - static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } - static std::string VariantNames() { return ObjectTypeChecker::TypeName(); } -}; - -template -struct ObjectTypeChecker> { - static Optional CheckAndGetMismatch(const Object* ptr) { - auto try_first = ObjectTypeChecker::CheckAndGetMismatch(ptr); - if (!try_first.defined()) { - return try_first; - } - - return ObjectTypeChecker>::CheckAndGetMismatch(ptr); - } - static bool Check(const Object* ptr) { - return ObjectTypeChecker::Check(ptr) || - ObjectTypeChecker>::Check(ptr); - } - static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } - static std::string VariantNames() { - return ObjectTypeChecker::TypeName() + ", " + - ObjectTypeChecker>::VariantNames(); - } -}; - /*! * \brief Internal base class to * handle conversion to POD values. */ class TVMPODValue_ { public: + operator double() const { + // Allow automatic conversion from int to float + // This avoids errors when user pass in int from + // the frontend while the API expects a float. + if (type_code_ == kDLInt) { + return static_cast(value_.v_int64); + } + TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); + return value_.v_float64; + } + operator int64_t() const { + TVM_CHECK_TYPE_CODE(type_code_, kDLInt); + return value_.v_int64; + } + operator uint64_t() const { + TVM_CHECK_TYPE_CODE(type_code_, kDLInt); + return value_.v_int64; + } + operator int() const { + TVM_CHECK_TYPE_CODE(type_code_, kDLInt); + ICHECK_LE(value_.v_int64, std::numeric_limits::max()); + ICHECK_GE(value_.v_int64, std::numeric_limits::min()); + return static_cast(value_.v_int64); + } + operator bool() const { + TVM_CHECK_TYPE_CODE(type_code_, kDLInt); + return value_.v_int64 != 0; + } operator void*() const { if (type_code_ == kTVMNullptr) return nullptr; if (type_code_ == kTVMDLTensorHandle) return value_.v_handle; @@ -636,39 +628,12 @@ class TVMPODValue_ { T* ptr() const { return static_cast(value_.v_handle); } - - std::optional TryAsBool() const { - // Helper function to reduce duplication in the variable integer - // conversions. This is publicly exposed, as it can be useful in - // specializations of PackedFuncValueConverter. - if (type_code_ == kTVMArgBool) { - return value_.v_bool; - } else { - return std::nullopt; - } - } - - std::optional TryAsInt() const { - // Helper function to reduce duplication in the variable integer - // conversions. This is publicly exposed, as it can be useful in - // specializations of PackedFuncValueConverter. - if (type_code_ == kDLInt) { - return value_.v_int64; - } else { - return std::nullopt; - } - } - - std::optional TryAsFloat() const { - // Helper function to reduce duplication in the variable integer - // conversions. This is publicly exposed, as it can be useful in - // specializations of PackedFuncValueConverter. - if (type_code_ == kDLFloat) { - return value_.v_float64; - } else { - return std::nullopt; - } - } + // ObjectRef handling + template ::value>::type> + inline bool IsObjectRef() const; + template + inline TObjectRef AsObjectRef() const; protected: friend class TVMArgsSetter; @@ -683,90 +648,13 @@ class TVMPODValue_ { int type_code_; }; -/*! \brief A utility class that adds methods useful for each POD type - * - * These cannot be provided in the base PODValue_ class, because - * TVMArgValue and TVMRetValue have different semantics for kTVMStr - * and kTVMBytes. - * - * kTVMStr: - * - * For `TVMArgValue`, the active variant is `v_str`, a `const - * char*`. For `TVMRetValue`, the active variant is `v_handle`, - * and should be cast from `void*` to `std::string*`. - * - * kTVMBytes: - * - * The active variant is `v_handle`, a `void*`. For - * `TVMArgValue`, should be cast to `TVMByteArray*`. For - * `TVMRetValue`, should be cast to `std::string*`. - * - * When converting into an `ObjectRef`, a string may be used to build - * a `tvm::runtime::String`. Because TVMArgValue and TVMRetValue use - * different representations for strings, any utility funciton which - * might attempt a conversion to an `ObjectRef` must be performed - * within a context that is aware of the derived class. - */ -template -class TVMPODValue_CRTP_ : public TVMPODValue_ { - public: - using TVMPODValue_::TVMPODValue_; - - // ObjectRef handling - template ::value>::type> - inline bool IsObjectRef() const; - template - inline TObjectRef AsObjectRef() const; - - operator double() const { - // Allow automatic conversion from int to float - // This avoids errors when user pass in int from - // the frontend while the API expects a float. - if (auto opt = TryAsFloat()) { - return opt.value(); - } else if (auto opt = TryAsInt()) { - return opt.value(); - } else if (auto opt = TryAsBool()) { - return opt.value(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); - } - } - operator int64_t() const { - if (auto opt = TryAsInt()) { - return opt.value(); - } else if (auto opt = TryAsBool()) { - return opt.value(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); - } - } - operator uint64_t() const { return operator int64_t(); } - operator int() const { - int64_t value = operator int64_t(); - ICHECK_LE(value, std::numeric_limits::max()); - ICHECK_GE(value, std::numeric_limits::min()); - return value; - } - operator bool() const { - if (auto opt = TryAsBool()) { - return opt.value(); - } else if (auto opt = TryAsInt()) { - return opt.value(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); - } - } -}; - /*! * \brief A single argument value to PackedFunc. * Containing both type_code and TVMValue * * Provides utilities to do type cast into other types. */ -class TVMArgValue : public TVMPODValue_CRTP_ { +class TVMArgValue : public TVMPODValue_ { public: /*! \brief default constructor */ TVMArgValue() {} @@ -775,21 +663,21 @@ class TVMArgValue : public TVMPODValue_CRTP_ { * \param value of the function * \param type_code The type code. */ - TVMArgValue(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} + TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} // reuse converter from parent - using TVMPODValue_CRTP_::operator double; - using TVMPODValue_CRTP_::operator int64_t; - using TVMPODValue_CRTP_::operator uint64_t; - using TVMPODValue_CRTP_::operator int; - using TVMPODValue_CRTP_::operator bool; + using TVMPODValue_::operator double; + using TVMPODValue_::operator int64_t; + using TVMPODValue_::operator uint64_t; + using TVMPODValue_::operator int; + using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_CRTP_::AsObjectRef; - using TVMPODValue_CRTP_::IsObjectRef; + using TVMPODValue_::AsObjectRef; + using TVMPODValue_::IsObjectRef; // conversion operator. operator std::string() const { @@ -826,15 +714,15 @@ class TVMArgValue : public TVMPODValue_CRTP_ { * * \note For internal development purpose only. */ -class TVMMovableArgValue_ : public TVMPODValue_CRTP_ { +class TVMMovableArgValue_ : public TVMPODValue_ { public: - TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} + TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} // reuse converter from parent - using TVMPODValue_CRTP_::operator double; - using TVMPODValue_CRTP_::operator int64_t; - using TVMPODValue_CRTP_::operator uint64_t; - using TVMPODValue_CRTP_::operator int; - using TVMPODValue_CRTP_::operator bool; + using TVMPODValue_::operator double; + using TVMPODValue_::operator int64_t; + using TVMPODValue_::operator uint64_t; + using TVMPODValue_::operator int; + using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; @@ -916,7 +804,7 @@ class TVMMovableArgValueWithContext_ { * TVMRetValue holds value and will manage the underlying containers * when it stores a complicated data type. */ -class TVMRetValue : public TVMPODValue_CRTP_ { +class TVMRetValue : public TVMPODValue_ { public: /*! \brief default constructor */ TVMRetValue() {} @@ -924,28 +812,28 @@ class TVMRetValue : public TVMPODValue_CRTP_ { * \brief move constructor from another return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) : TVMPODValue_CRTP_(other.value_, other.type_code_) { + TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kTVMNullptr; } /*! \brief destructor */ ~TVMRetValue() { this->Clear(); } // reuse converter from parent - using TVMPODValue_CRTP_::operator double; - using TVMPODValue_CRTP_::operator int64_t; - using TVMPODValue_CRTP_::operator uint64_t; - using TVMPODValue_CRTP_::operator int; - using TVMPODValue_CRTP_::operator bool; + using TVMPODValue_::operator double; + using TVMPODValue_::operator int64_t; + using TVMPODValue_::operator uint64_t; + using TVMPODValue_::operator int; + using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_CRTP_::AsObjectRef; - using TVMPODValue_CRTP_::IsObjectRef; + using TVMPODValue_::AsObjectRef; + using TVMPODValue_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_CRTP_() { this->Assign(other); } + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMDataType) { @@ -1013,8 +901,8 @@ class TVMRetValue : public TVMPODValue_CRTP_ { } TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { - this->SwitchToPOD(kTVMArgBool); - value_.v_bool = value; + this->SwitchToPOD(kDLInt); + value_.v_int64 = value; return *this; } TVMRetValue& operator=(std::string value) { @@ -1086,8 +974,7 @@ class TVMRetValue : public TVMPODValue_CRTP_ { */ static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle || - type_code == kTVMArgBool); + ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; @@ -1102,9 +989,9 @@ class TVMRetValue : public TVMPODValue_CRTP_ { } // ObjectRef handling template >> + typename = typename std::enable_if::value>::type> inline TVMRetValue& operator=(TObjectRef other); - template >> + template ::value>::type> inline operator T() const; private: @@ -1132,11 +1019,9 @@ class TVMRetValue : public TVMPODValue_CRTP_ { break; } case kTVMObjectHandle: { - // We already known it is not NDArray/Module, but - // operator=(ObjectRef) also handles conversions from wrappers - // around primitive types. For NDArray/Module, the duplicate - // checks are removed with if constexpr. - operator=(other.operator ObjectRef()); + // Avoid operator ObjectRef as we already know it is not NDArray/Module + SwitchToObject(kTVMObjectHandle, + GetObjectPtr(static_cast(other.value_.v_handle))); break; } case kTVMObjectRValueRefArg: { @@ -1380,8 +1265,6 @@ inline const char* ArgTypeCode2Str(int type_code) { switch (type_code) { case kDLInt: return "int"; - case kTVMArgBool: - return "bool"; case kDLUInt: return "uint"; case kDLFloat: @@ -1803,10 +1686,6 @@ class TVMArgsSetter { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } - TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const { - values_[i].v_bool = value; - type_codes_[i] = kTVMArgBool; - } TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); ICHECK_LE(value, static_cast(std::numeric_limits::max())); @@ -2072,110 +1951,38 @@ inline T TVMArgs::At(int i) const { template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { using ContainerType = typename std::remove_reference::type::ContainerType; - if (!value.defined()) { - type_codes_[i] = kTVMNullptr; - values_[i].v_handle = nullptr; - return; - } - - Object* ptr = value.data_.data_; - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { + if (value.defined()) { + Object* ptr = value.data_.data_; + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - return; - } - } - - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { + } else if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; - return; - } - } - - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { + } else if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMPackedFuncHandle; - return; - } - } - - // Like with BoxInt, unwrap any BoxBool instances. See the BoxInt - // explanation for more detail. - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { - values_[i].v_bool = static_cast(ptr)->value; - type_codes_[i] = kTVMArgBool; - return; - } - } - - // If a boxed integer is being returned, always unbox it to the - // primitive type. This must be checked at the PackedFunc level to - // ensure that a boxed primitive argument is round-tripped correctly - // when the boxing is no longer required. - // - // For example, consider a PackedFunc with signature `ObjectRef - // func(Array)`, and returns the first element of that - // array. When passing a Python array `[5, 17.5, "hello"]`, the - // items are converted to `[Box(5), Box(17.5), - // String("hello")]` in order to provide an `Array`. - // - // If we had no additional conversions, the caller would receive the - // return value as a `Box(5)`, which would be unexpected and - // require additional unwrapping. We could perform this check - // inside the PackedFunc, but that would require a large amount of - // duplicated checked, and would require explicit handling of - // `TVMRetValue`. Instead, this conversion is checked in the FFI - // return value, to ensure that boxing/unboxing is applied - // consistently. - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { - values_[i].v_int64 = static_cast(ptr)->value; - type_codes_[i] = kTVMArgInt; - return; - } - } - - // Like with BoxInt, unwrap any BoxFloat instances. See the BoxInt - // explanation for more detail. - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { - values_[i].v_float64 = static_cast(ptr)->value; - type_codes_[i] = kTVMArgFloat; - return; + } else if (std::is_rvalue_reference::value) { + values_[i].v_handle = const_cast(&(value.data_.data_)); + type_codes_[i] = kTVMObjectRValueRefArg; + } else { + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kTVMObjectHandle; } - } - - // Final fallback, if the ObjectRef has no special cases that must - // be expressed within the TVMRetValue. - if constexpr (std::is_rvalue_reference_v) { - values_[i].v_handle = const_cast(&(value.data_.data_)); - type_codes_[i] = kTVMObjectRValueRefArg; } else { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kTVMObjectHandle; + type_codes_[i] = kTVMNullptr; + values_[i].v_handle = nullptr; } } -template template -inline bool TVMPODValue_CRTP_::IsObjectRef() const { +inline bool TVMPODValue_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { @@ -2205,9 +2012,8 @@ inline bool TVMPODValue_CRTP_::IsObjectRef() const { ObjectTypeChecker::Check(static_cast(value_.v_handle))); } -template template -inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { +inline TObjectRef TVMPODValue_::AsObjectRef() const { static_assert(std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; @@ -2217,10 +2023,8 @@ inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } - - // NOTE: The following code uses "if constexpr" wherever possible to - // minimize the number of runtime checks. - if constexpr (std::is_base_of_v) { + // NOTE: the following code can be optimized by constant folding. + if (std::is_base_of::value) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); ObjectPtr data = @@ -2229,8 +2033,7 @@ inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - - if constexpr (std::is_base_of_v) { + if (std::is_base_of::value) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2238,8 +2041,7 @@ inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - - if constexpr (std::is_base_of_v) { + if (std::is_base_of::value) { // Casting to a sub-class of PackedFunc TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2247,7 +2049,6 @@ inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (type_code_ == kTVMObjectHandle) { // normal object type check. Object* ptr = static_cast(value_.v_handle); @@ -2261,152 +2062,51 @@ inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker::TypeName() << ", but got " << checked_type.value(); return TObjectRef(GetObjectPtr(ptr)); + } else if (std::is_base_of::value && + type_code_ == kTVMNDArrayHandle) { + // Casting to a base class that NDArray can sub-class + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); + return TObjectRef(data); + } else if (std::is_base_of::value && + type_code_ == kTVMModuleHandle) { + // Casting to a base class that Module can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } else if (std::is_base_of::value && + type_code_ == kTVMPackedFuncHandle) { + // Casting to a base class that PackedFunc can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } else { + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); } - - if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMNDArrayHandle) { - // Casting to a base class that NDArray can sub-class - ObjectPtr data = - NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); - return TObjectRef(data); - } - } - - if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMModuleHandle) { - // Casting to a base class that Module can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } - } - - if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMPackedFuncHandle) { - // Casting to a base class that PackedFunc can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } - } - - if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMArgInt) { - return Int(value_.v_int64); - } - } - - if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMArgFloat) { - return Float(value_.v_float64); - } - } - - if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMArgBool) { - return Bool(value_.v_bool); - } - } - - if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMStr || type_code_ == kTVMBytes) { - // This step is the reason why `AsObjectRef` cannot be provided - // in the base `TVMPODValue_` class. Because `TVMArgValue` and - // `TVMRetValue` have different implementations of `operator - // std::string`, with different interpretations of `kTVMStr` and - // `kTVMBytes`, we must delegate to those implementations. - // - // This could be done with a pure virtual method in - // `TVMPODValue_`, but that would require a vtable lookup during - // FFI conversions, imposing a runtime overhead. - return String(static_cast(this)->operator std::string()); - } - } - - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - return TObjectRef(ObjectPtr(nullptr)); } template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); - - if (ptr) { - // Check for special cases of ObjectRef that have explicit - // representation within the TVMRetValue structure. - // (e.g. Unboxing of `runtime::Int` into a primitive integer - // with type code kTVMArgInt.) The checks below are written to - // handle three distinct cases. - // - // 1. If TObjectRef is a subclass of TSpecialCase, the special - // case applies, and can be handled without a runtime check. - // No runtime checks should be performed. - // - // 2. If TSpecialCase is a subclass of TObjectRef, the special - // case might apply, and requires a runtime check. - // - // 3. If neither TObjectRef nor TSpecialCase is a subclass of - // the other, then the special case does not apply. No - // runtime checks should be performed. - // - // Use of `if constexpr` ensures that the C++ subclass checks - // are applied when compiling TVM, and runtime overhead are only - // present when they may be applicable. - - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { - return operator=(NDArray(std::move(other.data_))); - } - } - - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { - return operator=(Module(std::move(other.data_))); - } - } - - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - if (std::is_base_of_v || - ptr->IsInstance()) { - return operator=(PackedFunc(std::move(other.data_))); - } - } - - if constexpr (std::is_base_of_v || std::is_base_of_v) { - if (std::is_base_of_v || ptr->IsInstance()) { - bool value = static_cast(ptr)->value; - return operator=(value); - } + if (ptr != nullptr) { + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + return operator=(NDArray(std::move(other.data_))); + } + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + return operator=(Module(std::move(other.data_))); + } + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + return operator=(PackedFunc(std::move(other.data_))); } - - if constexpr (std::is_base_of_v || std::is_base_of_v) { - if (std::is_base_of_v || ptr->IsInstance()) { - int64_t value = static_cast(ptr)->value; - return operator=(value); - } - } - - if constexpr (std::is_base_of_v || std::is_base_of_v) { - if (std::is_base_of_v || ptr->IsInstance()) { - double value = static_cast(ptr)->value; - return operator=(value); - } - } - - // If the object being stored is not one of the special cases, - // it is stored as an ObjectRef. SwitchToObject(kTVMObjectHandle, std::move(other.data_)); - } else { - // No object is present, set to an explicitly null handle. When - // returning to a Python callee, this will be converted to - // `None`. SwitchToPOD(kTVMNullptr); value_.v_handle = nullptr; } - return *this; } @@ -2439,123 +2139,20 @@ inline PackedFunc Module::GetFunction(const String& name, bool query_imports) { // specializations of PackedFuncValueConverter template <> struct PackedFuncValueConverter<::tvm::runtime::String> { - template - static String From(const PODSubclass& val) { - if (val.template IsObjectRef()) { - return val.template AsObjectRef(); + static String From(const TVMArgValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); } else { return tvm::runtime::String(val.operator std::string()); } } -}; -template -struct PackedFuncValueConverter> { - static Array From(const TVMArgValue& val) { - auto untyped_array = val.AsObjectRef>(); - - // Attempt to convert each item of the array into the desired - // type. If the items do not require a conversion, no copies are - // made. - return untyped_array.Map([](ObjectRef item) { - // Recursively apply any conversions that have been registered - // with TVM's FFI. - // - // For example, a function that accepts `Array` may - // be called from python with argument `[1,2]`. By the time - // `PackedFuncValueConverter::From` is called, the python list - // has been converted to `Array`, with contents - // converted into `runtime::Int`. Converting the `ObjectRef` - // to `TVMArgValue` unboxes the `runtime::Int` back into a - // primitive with type code `kTVMArgInt`. This primitive can - // then be converted to a PrimExpr using - // `PackedFuncValueConverter::From`. - // - // The use of two conversions, first from python `int` to - // `runtime::Int` and then from `runtime::Int` to `PrimExpr`, - // is a result of the split between `libtvm_runtime.so` and - // `libtvm.so`. The FFI must function correctly in both - // cases, and so conversions applied by default in the Python - // FFI implementation may only produce types that are - // available in both libraries. In the C++ FFI implementation - // (i.e. this file), libtvm.so may apply additional - // conversions that are not present in libtvm_runtime.so. - TVMValue value; - int type_code; - TVMArgsSetter setter(&value, &type_code); - setter(0, item); - TVMArgValue arg(value, type_code); - return PackedFuncValueConverter::From(arg); - }); - } - static Array From(const TVMRetValue& val) { - auto untyped_array = val.AsObjectRef>(); - - return untyped_array.Map([](ObjectRef item) { - TVMRetValue item_val; - item_val = std::move(item); - return PackedFuncValueConverter::From(item_val); - }); - } -}; - -template -struct PackedFuncValueConverter> { - static Map From(const TVMArgValue& val) { - auto untyped_map = val.AsObjectRef>(); - - if (ObjectTypeChecker>::Check(untyped_map.get())) { - // Early bail-out for common case where no type conversions are - // required. - return Downcast>(untyped_map); - } - - Map output; - for (const auto& kv : untyped_map) { - T new_key = [&]() { - TVMValue pod_value; - int type_code; - TVMArgsSetter setter(&pod_value, &type_code); - setter(0, kv.first); - TVMArgValue pod_arg(pod_value, type_code); - return PackedFuncValueConverter::From(pod_arg); - }(); - U new_value = [&]() { - TVMValue pod_value; - int type_code; - TVMArgsSetter setter(&pod_value, &type_code); - setter(0, kv.second); - TVMArgValue key_arg(pod_value, type_code); - return PackedFuncValueConverter::From(key_arg); - }(); - output.Set(new_key, new_value); - } - return output; - } - static Map From(const TVMRetValue& val) { - auto untyped_map = val.AsObjectRef>(); - - if (ObjectTypeChecker>::Check(untyped_map.get())) { - // Early bail-out for common case where no type conversions are - // required. - return Downcast>(untyped_map); - } - - Map output; - for (const auto& kv : untyped_map) { - T new_key = [&]() { - TVMRetValue pod; - pod = kv.first; - return PackedFuncValueConverter::From(pod); - }(); - U new_value = [&]() { - TVMRetValue pod; - pod = kv.second; - return PackedFuncValueConverter::From(pod); - }(); - output.Set(new_key, new_value); + static String From(const TVMRetValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); } - return output; } }; @@ -2584,7 +2181,7 @@ struct PackedFuncValueConverter> { return opt.value(); } - if (auto opt = TryValueConverter(val)) { + if (auto opt = TryValueConverter(val)) { return opt.value(); } @@ -2595,10 +2192,10 @@ struct PackedFuncValueConverter> { << " but got " << ArgTypeCode2Str(val.type_code()); } - template - static Optional TryAsObjectRef(const PODSubclass& val) { - if (val.template IsObjectRef()) { - return VType(val.template AsObjectRef()); + template + static Optional TryAsObjectRef(const TVMPODValue_& val) { + if (val.IsObjectRef()) { + return VType(val.AsObjectRef()); } else if constexpr (sizeof...(VarRest)) { return TryAsObjectRef(val); } else { @@ -2606,15 +2203,15 @@ struct PackedFuncValueConverter> { } } - template + template static Optional TryValueConverter(const PODSubclass& val) { try { return VType(PackedFuncValueConverter::From(val)); - } catch (const Error&) { + } catch (const InternalError&) { } if constexpr (sizeof...(VarRest)) { - return TryValueConverter(val); + return TryValueConverter(val); } else { return NullOpt; } diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 4c1d1fc1f3d2..d47ac94e067e 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -113,15 +113,7 @@ class TargetNode : public Object { "Can only call GetAttr with ObjectRef types."); auto it = attrs.find(attr_key); if (it != attrs.end()) { - // For backwards compatibility, return through TVMRetValue. - // This triggers any automatic conversions registered with - // PackedFuncValueConverter. Importantly, this allows use of - // `GetAttr` and `GetAttr` for properties that - // are stored internally as `runtime::Box` and - // `runtime::Box`. - TVMRetValue ret; - ret = (*it).second; - return ret; + return Downcast>((*it).second); } else { return default_value; } diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 6b3b9c31a645..130aea32f844 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -445,8 +445,8 @@ constexpr const char* kRelayToTIR = "RelayToTIR"; .add_attr_option("model") \ .add_attr_option>("libs") \ .add_attr_option("host") \ - .add_attr_option("from_device") \ - .add_attr_option("target_device_type") + .add_attr_option("from_device") \ + .add_attr_option("target_device_type") } // namespace tvm diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 28cb022151d2..d9b65dc8745c 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1155,63 +1155,6 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } // namespace tir } // namespace tvm -namespace tvm { -namespace runtime { - -// Automatic conversion into PrimExpr, when called through the FFI. -// Automatic conversions into IntImm, Integer, and Bool are registered -// in "tvm/ir/expr.h", as they are currently in use outside of TIR. - -template <> -struct PackedFuncValueConverter { - template - static Optional TryFrom(const PODSubclass& val) { - auto type_code = val.type_code(); - bool can_convert = type_code == kTVMDataType || type_code == kTVMBytes || - type_code == kTVMStr || val.template IsObjectRef(); - if (can_convert) { - return tvm::tir::StringImm(PackedFuncValueConverter::From(val)); - } else { - return NullOpt; - } - } - - template - static tvm::tir::StringImm From(const PODSubclass& val) { - if (auto opt = TryFrom(val)) { - return opt.value(); - } else { - return val.template AsObjectRef(); - } - } -}; - -template <> -struct PackedFuncValueConverter { - // Common rule for RetValue and ArgValue. Templated to ensure - // correct delegation to `operator std::string()` for either - // TVMArgValue or TVMRetValue. - template - static PrimExpr From(const PODSubclass& val) { - if (auto opt = val.TryAsBool()) { - // Check against val.TryAsBool directly, to avoid the - // bounds-checking in PackedFuncValueConverter::TryFrom. - return tvm::Bool(opt.value()); - } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { - return opt.value(); - } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { - return opt.value(); - } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { - return opt.value(); - } else { - return PrimExpr::FromObject_(val.template AsObjectRef()); - } - } -}; - -} // namespace runtime -} // namespace tvm - namespace std { template <> struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1d218c6a7c61..274ebd0a6558 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -264,7 +264,7 @@ class TensorIntrin : public ObjectRef { * B[vi, vj] = A[vi, vj] * \endcode */ -PrimFunc Specialize(PrimFunc func, const Map>& param_map); +PrimFunc Specialize(PrimFunc func, const Map& param_map); /*! * \brief PrimFunc specific attribute names. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 092bd52d5634..9b23973b6f8f 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -224,9 +224,8 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return The random variable sampled from candidates */ - virtual ExprRV SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision = NullOpt) = 0; + virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) = 0; /*! * \brief Sample the factors to perfect tile a specific loop * \param loop_rv The loop to be tiled diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 8f674eea2ec6..520e0e42ebbe 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -60,36 +60,14 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) - - # Handle return values that subclass from both TVM objects and - # python native objects (e.g. runtime.String, a subclass of str). if issubclass(cls, PyNativeObject): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) obj.handle = handle return cls.__from_tvm_object__(cls, obj) - # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) obj.handle = handle - - # Handle return values that must be converted from the TVM object - # to a python native object. This should be used in cases where - # subclassing the python native object is forbidden. For example, - # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does - # not allow any subclasses. - # - # The `hasattr` check is done on the object's class, not the - # object itself, to avoid edge cases that can result in invalid - # error messages. If a C++ `LOG(FATAL) << nested_obj;` statement - # requires C++ to Python conversions in order to print - # `nested_obj`, then the `AttributeError` used internally by - # `hasattr` may overwrite the text being collected by - # `LOG(FATAL)`. By checking for the method on the class instead - # of the instance, we avoid throwing the `AttributeError`. - # if hasattr(type(obj), "__into_pynative_object__"): - # return obj.__into_pynative_object__() - return obj diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 6dab1a5db1f4..5f3aa04914be 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -134,11 +134,6 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode - elif isinstance(arg, bool): - # A python `bool` is a subclass of `int`, so this check - # must occur before `Integral`. - values[i].v_bool = arg - type_codes[i] = ArgTypeCode.BOOL elif isinstance(arg, Integral): values[i].v_int64 = arg type_codes[i] = ArgTypeCode.INT @@ -152,7 +147,7 @@ def _make_tvm_args(args, temp_args): values[i].v_int64 = _device_to_int64(arg) type_codes[i] = ArgTypeCode.DLDEVICE elif isinstance(arg, (bytearray, bytes)): - # from_buffer only takes in bytearray. + # from_buffer only taeks in bytearray. if isinstance(arg, bytes): byte_arr = bytearray(arg) temp_args.append(byte_arr) diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index 45f36eafd78a..38d3cd72b55d 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -27,7 +27,6 @@ class TVMValue(ctypes.Union): _fields_ = [ ("v_int64", ctypes.c_int64), - ("v_bool", ctypes.c_bool), ("v_float64", ctypes.c_double), ("v_handle", ctypes.c_void_p), ("v_str", ctypes.c_char_p), @@ -95,7 +94,6 @@ def _device_to_int64(dev): RETURN_SWITCH = { ArgTypeCode.INT: lambda x: x.v_int64, - ArgTypeCode.BOOL: lambda x: x.v_bool, ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, @@ -106,7 +104,6 @@ def _device_to_int64(dev): C_TO_PY_ARG_SWITCH = { ArgTypeCode.INT: lambda x: x.v_int64, - ArgTypeCode.BOOL: lambda x: x.v_bool, ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 0f7e5fcae6bd..69e1355f7d13 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -16,7 +16,6 @@ # under the License. from ..base import raise_last_ffi_error -from libcpp cimport bool as bool_t from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule @@ -39,8 +38,7 @@ cdef enum TVMArgTypeCode: kTVMBytes = 12 kTVMNDArrayHandle = 13 kTVMObjectRefArg = 14 - kTVMArgBool = 15 - kTVMExtBegin = 16 + kTVMExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct DLDataType: @@ -68,7 +66,6 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct TVMValue: int64_t v_int64 - bool_t v_bool double v_float64 void* v_handle const char* v_str diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index ff38cd3d0ec2..94a9310d7815 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -60,17 +60,7 @@ cdef inline object make_ret_object(void* chandle): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = chandle - - # Handle return values that must be converted from the TVM object - # to a python native object. This should be used in cases where - # subclassing the python native object is forbidden. For example, - # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does - # not allow any subclasses. - # if hasattr(obj, '__into_pynative_object__'): - # return obj.__into_pynative_object__) - return obj - # return obj.__into_pynative_object__() class PyNativeObject: diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 7977f37d0be5..3d1e87bf563d 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -45,7 +45,7 @@ cdef int tvm_callback(TVMValue* args, tcode == kTVMModuleHandle or tcode == kTVMNDArrayHandle or tcode == kTVMObjectRefArg or - tcode >= kTVMExtBegin): + tcode > kTVMExtBegin): CHECK_CALL(TVMCbArgToReturn(&value, &tcode)) if tcode != kTVMDLTensorHandle: @@ -118,11 +118,6 @@ cdef inline int make_arg(object arg, ptr = arg._tvm_handle value[0].v_handle = (ptr) tcode[0] = arg.__class__._tvm_tcode - elif isinstance(arg, bool): - # A python `bool` is a subclass of `int`, so this check - # must occur before `Integral`. - value[0].v_bool = arg - tcode[0] = kTVMArgBool elif isinstance(arg, Integral): value[0].v_int64 = arg tcode[0] = kInt @@ -214,8 +209,6 @@ cdef inline object make_ret(TVMValue value, int tcode): return make_ret_object(value.v_handle) elif tcode == kTVMNullptr: return None - elif tcode == kTVMArgBool: - return value.v_bool elif tcode == kInt: return value.v_int64 elif tcode == kFloat: diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 03dc18ea6e0b..f148e26f3fcb 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -48,8 +48,7 @@ class ArgTypeCode(object): BYTES = 12 NDARRAY_HANDLE = 13 OBJECT_RVALUE_REF_ARG = 14 - BOOL = 15 - EXT_BEGIN = 16 + EXT_BEGIN = 15 class TVMByteArray(ctypes.Structure): diff --git a/python/tvm/driver/tvmc/registry.py b/python/tvm/driver/tvmc/registry.py index b76202a730a2..c2e74eb1935e 100644 --- a/python/tvm/driver/tvmc/registry.py +++ b/python/tvm/driver/tvmc/registry.py @@ -20,23 +20,11 @@ from tvm.driver.tvmc import TVMCException -# We can't tell the type inside an Array but all current options are -# strings so it can default to that. runtime.BoxBool is used to -# distinguish from runtime.BoxInt. -INTERNAL_TO_NATIVE_TYPE = { - "runtime.String": str, - "runtime.BoxBool": bool, - "runtime.BoxFloat": float, - "runtime.BoxInt": int, - "Array": str, -} -INTERNAL_TO_HELP = { - "runtime.String": " string", - "runtime.BoxBool": " bool", - "runtime.BoxInt": " int", - "runtime.BoxFloat": " float", - "Array": " options", -} +# We can't tell the type inside an Array but all current options are strings so +# it can default to that. Bool is used alongside Integer but aren't distinguished +# between as both are represented by IntImm +INTERNAL_TO_NATIVE_TYPE = {"runtime.String": str, "IntImm": int, "Array": str} +INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} def _generate_registry_option_args(parser, registry, name): diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index 6afb383c9f04..6f0a6dd7d155 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -61,7 +61,7 @@ def get_int_tuple(self, key): ------- value: Tuple of int """ - return tuple(x if isinstance(x, int) else x.value for x in self.__getattr__(key)) + return tuple(x.value for x in self.__getattr__(key)) def get_int(self, key): """Get a python int value of a key diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 263976fa98ff..c70ac2acc71b 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -20,7 +20,7 @@ import tvm._ffi -from ..runtime import Object, Scriptable +from ..runtime import Object, Scriptable, const, convert from . import _ffi_api from .base import Node, Span from .type import Type @@ -184,6 +184,9 @@ class Range(Node, Scriptable): def __init__( self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None ) -> None: + if end is None: + end = convert(begin) + begin = const(0, dtype=end.dtype, span=span) self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span) @staticmethod diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 51d9a013d8b3..6f76452a57b5 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -28,7 +28,6 @@ from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule -from tvm.script import tir as T from . import _ffi_api from .logging import Logger, get_logger, get_logging_func @@ -48,7 +47,7 @@ def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: if isinstance(mod, PrimFunc): if not (mod.attrs and "global_symbol" in mod.attrs): mod = mod.with_attr("global_symbol", "main") - mod = mod.with_attr("tir.noalias", T.bool(True)) + mod = mod.with_attr("tir.noalias", True) mod = IRModule({"main": mod}) if not isinstance(mod, IRModule): raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") diff --git a/python/tvm/relax/op/statistical.py b/python/tvm/relax/op/statistical.py index 502d058ffdf6..eb44696871eb 100644 --- a/python/tvm/relax/op/statistical.py +++ b/python/tvm/relax/op/statistical.py @@ -195,7 +195,7 @@ def cumprod( data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None, - exclusive: bool = False, + exclusive: Optional[bool] = None, ): """Numpy style cumprod op. Return the cumulative product of the elements along a given axis. @@ -213,9 +213,9 @@ def cumprod( Type of the returned array and of the accumulator in which the elements are computed. If dtype is not specified, it defaults to the dtype of data. - exclusive : bool - If false (default), all elements are included in the product. If - true, the first element is excluded from the product. + exclusive : Optional[bool] + If true will return exclusive sum in which the first element is not + included. Returns ------- @@ -247,9 +247,6 @@ def cumprod( cumprod(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 1, 0, 0, 0, 0] """ - if exclusive is None: - exclusive = False - return _ffi_api.cumprod(data, axis, dtype, exclusive) # type: ignore @@ -257,7 +254,7 @@ def cumsum( data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None, - exclusive: bool = False, + exclusive: Optional[bool] = None, ): """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along a given axis. @@ -275,9 +272,9 @@ def cumsum( Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. - exclusive : bool - If false (default), all elements are included in the sum. If - true, the first element is excluded from the sum. + exclusive : Optional[bool] + If true will return exclusive sum in which the first element is not + included. Returns ------- @@ -309,9 +306,6 @@ def cumsum( cumsum(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 2, 2, 3, 4, 4] """ - if exclusive is None: - exclusive = False - return _ffi_api.cumsum(data, axis, dtype, exclusive) # type: ignore diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 4c670bbe74b2..1ed16363b20a 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -171,19 +171,11 @@ def visit_call_(self, op: relax.Call) -> str: def display_attrs(attr_key): attr_val = op.attrs[attr_key] - - if isinstance(attr_val, str): - # attrs can be strings but also other types; - # we want to wrap strings in quotes - # (__repr__ would work but it uses single quotes) - attr_val = wrap_quotes(attr_val) - elif isinstance(attr_val, tvm.tir.IntImm): - if attr_val.dtype == "bool": - attr_val = bool(attr_val.value) - else: - attr_val = int(attr_val.value) - - return f"{wrap_quotes(attr_key)}: {attr_val}" + # attrs can be strings but also other types; + # we want to wrap strings in quotes + # (__repr__ would work but it uses single quotes) + attr_str = wrap_quotes(attr_val) if isinstance(attr_val, str) else str(attr_val) + return f"{wrap_quotes(attr_key)}: {attr_str}" fields["attrs"] = self.build_list( map(display_attrs, op.attrs.keys()), diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index aba7ae912c54..71bf8509a63e 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -139,14 +139,14 @@ def _check_well_formed(self, mod: IRModule): # Check function attrs if not self.PARAM_NUM_ATTR_KEY in mod.attrs or not isinstance( - mod.attrs[self.PARAM_NUM_ATTR_KEY], (IntImm, int) + mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " f"{self.PARAM_NUM_ATTR_KEY}" ) if not self.STATE_NUM_ATTR_KEY in mod.attrs or not isinstance( - mod.attrs[self.STATE_NUM_ATTR_KEY], (IntImm, int) + mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index e1cab4cbd53b..9323bc40da69 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -97,9 +97,6 @@ def convert_to_expr(value: Any) -> Expr: if isinstance(value, int): return PrimValue(tir.IntImm("int64", value)) - if isinstance(value, float): - return PrimValue(tir.FloatImm("float64", value)) - tvm_value = convert_to_object(value) # Case 1 if isinstance(tvm_value, Expr): # type: ignore diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 199193f75939..97d7cfa93c8d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -76,7 +76,7 @@ def get_section_begin_coords(split: tvm.relay.Expr) -> List[int]: # 0 is the beginning of the first section. return [0] + list(indices_or_sections) split_axis_len = input_shape[split_axis].value - section_length = split_axis_len // indices_or_sections + section_length = split_axis_len // indices_or_sections.value return list(range(0, split_axis_len, section_length)) def callback( diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index dca7b995b22d..6b9b311c83b5 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Gradient definitions for Relay operators""" -import tvm from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.error import OpError @@ -384,8 +383,6 @@ def concatenate_grad(orig, grad): axis_dims = [ty.shape[orig.attrs.axis] for ty in t.checked_type.fields] splits, cumsum = [], 0 for dim in axis_dims[:-1]: - if isinstance(dim, tvm.tir.IntImm): - dim = dim.value cumsum += dim splits.append(cumsum) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 8bca72655491..93df67ff6b99 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1057,10 +1057,10 @@ def split_shape_func(attrs, inputs, _): return [ _split_shape_func( inputs[0], - i, - indices_or_sections, - param_is_indices, - axis, + convert(i), + convert(indices_or_sections), + convert(param_is_indices), + convert(axis), ) for i in range(num_out) ] diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index c4eff3fcc9e0..dd04d613079b 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1630,10 +1630,10 @@ def __init__(self, func_body): def convert_indices_or_sections(self, indices_or_sections): # split_v if isinstance(indices_or_sections, tvm.ir.container.Array): - values = [int(i) for i in indices_or_sections] + values = [i.value for i in indices_or_sections] # split else: - values = int(indices_or_sections) + values = indices_or_sections.value return values def is_valid(self): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index dd9c670e2a37..ef1cdb3afdd8 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -18,8 +18,6 @@ # pylint: disable=import-outside-toplevel """Transform operators.""" -from typing import Optional - from ...tir import expr as _expr from ..expr import Constant, Expr, Tuple, TupleWrapper, const from . import _make @@ -857,14 +855,13 @@ def broadcast_to(data, shape): The resulting tensor. """ if isinstance(shape, Constant): - shape = shape.data.numpy() - shape = [_expr.IntImm(str(shape.dtype), int(value)) for value in shape] - elif isinstance(shape, Expr): + shape = list(shape.data.numpy()) + if isinstance(shape, Expr): return _dyn_make.broadcast_to(data, shape) - if isinstance(shape, int): shape = [shape] - + if isinstance(shape, (list, tuple)): + shape = list(shape) return _make.broadcast_to(data, shape) @@ -1941,8 +1938,9 @@ def stft( return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided) -def dft(re_data, im_data, inverse: Optional[bool] = False): - """Computes the discrete Fourier transform of input (calculation along the last axis). +def dft(re_data, im_data, inverse=False): + """ + Computes the discrete Fourier transform of input (calculation along the last axis). This gives frequency components of the signal as they change over time. Parameters @@ -1954,11 +1952,8 @@ def dft(re_data, im_data, inverse: Optional[bool] = False): N-D tensor, imaginary part of the input signal. If the signal is real, then the values of this tensor are zeros. - inverse : Optional[bool] - + inverse : bool Whether to perform the inverse discrete fourier transform. - Providing None is equivalent to False, and is maintained for - compatibility. Returns ------- @@ -1966,11 +1961,7 @@ def dft(re_data, im_data, inverse: Optional[bool] = False): The Fourier Transform of the input (Real part). im_output : relay.Expr The Fourier Transform of the input (Imaginary part). - """ - if inverse is None: - inverse = False - return TupleWrapper(_make.dft(re_data, im_data, inverse), 2) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 6eef6ff3ffae..7ad838895c9f 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -364,8 +364,9 @@ def split(expr, type_map): arg = expr.args[0] t = type_map[arg] attrs = {**expr.attrs} - if isinstance(attrs["indices_or_sections"], int): - num_split = attrs["indices_or_sections"] + if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm): + num_split = attrs["indices_or_sections"].value + attrs["indices_or_sections"] = num_split else: num_split = len(attrs["indices_or_sections"]) + 1 return [expr, TupleAffineType([t] * num_split)] diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 301f0ef66286..f182cd9bfd2f 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -27,11 +27,11 @@ from .profiling import Report # function exposures +from .object_generic import convert_to_object, convert, const from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library -from .container import String, ShapeTuple # , BoxBool -from .object_generic import convert_to_object, convert, const +from .container import String, ShapeTuple from .params import ( save_param_dict, load_param_dict, diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index f1a0706a387d..686b4a26c80c 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -172,41 +172,3 @@ def __eq__(self, other): return False return True - - -# @tvm._ffi.register_object("runtime.BoxBool") -# class BoxBool(Object): -# """A boolean wrapped as a tvm Object - -# Parameters -# ---------- -# value: bool - -# The value to hold -# """ - -# def __init__(self, value: bool): -# # Convert to int to avoid an infinite recursion, because -# # BoxBool may be constructed in _make_tvm_args, and calling -# # the packed func `_ffi_api.BoxBool` internally calls -# # `_make_tvm_args`. -# self.__init_handle_by_constructor__(_ffi_api.BoxBool, int(value)) - -# def __into_pynative_object__(self) -> bool: -# return self.value - -# @property -# def value(self) -> bool: -# """Unwrap the boxed value. - -# This is implemented explicitly rather than using the usual -# PackedFunc handling or AttrVisitor mechanics for two reasons. -# First, because the PackedFunc handling would require ambiguous -# representations between `True`/`1` and `False`/`0`. Second, -# because the boxing/unboxing must be available in -# `libtvm_runtime.so`, and AttrVisitor is only available in -# `libtvm.so`. -# """ -# unboxed_bool = _ffi_api.UnBoxBool(self) -# assert unboxed_bool is not None -# return bool(unboxed_bool) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 20909c53c787..887c2faaeb2b 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -38,62 +38,65 @@ def asobject(self): ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PackedFuncBase, PyNativeObject) -def convert_to_object(value): +def convert_to_object(value, span=None): """Convert a Python value to corresponding object type. - Type conversions performed by this function must *only* produce - types that are supported by `libtvm_runtime.so`. This function - must be usable in environments where only TVM runtime support is - present. Automatic conversions to compile-time representations - (e.g. `tir.IntImm` or `relax.PrimValue`) should not be done as - part of this conversion, as these types are not available in - `libtvm_runtime.so`. - Parameters ---------- value : str The value to be inspected. + span : Optional[Span] + The location of this itervar in the source code. + Returns ------- obj : Object The corresponding object value. - """ - if isinstance(value, ObjectTypes): return value - elif isinstance(value, (bool, int, float)): - return value - elif isinstance(value, string_types): + if isinstance(value, bool): + return const(value, "uint1x1", span=span) + if isinstance(value, Number): + return const(value, span=span) + if isinstance(value, string_types): return _ffi_api.String(value) - elif isinstance(value, (list, tuple)): - # The call to _ffi_api.Array will convert its own arguments, - # so we don't need to apply any explicit conversions here. + if isinstance(value, (list, tuple)): + value = [convert_to_object(x) for x in value] return _ffi_api.Array(*value) - elif isinstance(value, dict): - if any(not isinstance(key, (ObjectTypes, string_types, Number)) for key in value): - raise ValueError("key of map must already been a container type") - - vlist = [kv for item in value.items() for kv in item] + if isinstance(value, dict): + vlist = [] + for item in value.items(): + if ( + not isinstance(item[0], ObjectTypes) + and not isinstance(item[0], string_types) + and not isinstance(item[0], Number) + ): + raise ValueError("key of map must already been a container type") + vlist.append(convert_to_object(item[0])) + vlist.append(convert_to_object(item[1])) return _ffi_api.Map(*vlist) - elif isinstance(value, ObjectGeneric): + if isinstance(value, ObjectGeneric): return value.asobject() - elif callable(value): + if callable(value): return convert_to_tvm_func(value) - elif value is None: + if value is None: return None - else: - raise TypeError(f"don't know how to convert type {type(value)} to object") + + raise ValueError(f"don't know how to convert type {type(value)} to object") -def convert(value): +def convert(value, span=None): """Convert value to TVM object or function. Parameters ---------- value : python value + span : Optional[Span] + The location of this statement in the source code. + Returns ------- tvm_val : Object or Function @@ -104,29 +107,29 @@ def convert(value): This function is redirected to `convert_to_object` as it is widely used in the codebase. We can choose one to keep and discard the other one later. """ - - return convert_to_object(value) + return convert_to_object(value, span=span) def _scalar_type_inference(value): if hasattr(value, "dtype"): - return str(value.dtype) + dtype = str(value.dtype) elif isinstance(value, bool): - return "bool" + dtype = "bool" elif isinstance(value, float): # We intentionally prefer convert the float to float32 since it's more common in DL. if -3.40282347e38 <= value <= 3.40282347e38: - return "float32" + dtype = "float32" else: - return "float64" + dtype = "float64" elif isinstance(value, int): # We intentionally prefer convert the python int to int32 since it's more common in DL. if -2147483648 <= value <= 2147483647: - return "int32" + dtype = "int32" else: - return "int64" + dtype = "int64" else: raise NotImplementedError(f"Cannot automatically inference the type. value={value}") + return dtype def const(value, dtype=None, span=None): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 3107354ac353..e545bc3a5e53 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -536,8 +536,6 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ value = self.eval_expr(node.value) - if value is None: - self.report_error(node, "Expression to be returned must be a PrimExpr") T.evaluate(tvm.tir.ret(value)) diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 948a0d7665ff..462066106a9d 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -96,7 +96,7 @@ def _allocate_tensor(func_id, args): ) shape = args[0] for i in shape: - _internal_assert(isinstance(i, (_expr.PrimExpr, int)), "The shape should be an expression") + _internal_assert(isinstance(i, _expr.PrimExpr), "The shape should be an expression") if n > 1: _internal_assert(isinstance(args[1], str), "The data type should be an str") _internal_assert( @@ -131,11 +131,9 @@ def len(func_id, args): def _cast(func_id, args): _internal_assert( - args.__len__() == 1, - f"Casting to {func_id} only supports a single argument", + args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), + "Only one expression can be cast", ) - # The FFI can handle any conversion of `args[0]` into PrimExpr, if - # required. return _expr.Cast(func_id, args[0]) @@ -147,7 +145,9 @@ def _cast(func_id, args): def ceil_div(func_id, args): _internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!") _internal_assert(args.__len__() == 2, "2 arguments expected for division!") - a, b = args + _internal_assert(isinstance(args[0], _expr.PrimExpr), "Only expressions can div") + _internal_assert(isinstance(args[1], _expr.PrimExpr), "Only expressions can div") + a, b = args[0], args[1] return (a + b - 1) // b diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index bd5a060cd01c..846ef818ea54 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -279,7 +279,7 @@ def visit_Num(self, node): return tvm.runtime.const(node.n, dtype) def visit_NameConstant(self, node): - return tvm.tir.const(node.value) + return tvm.runtime.convert(node.value) def visit_AugAssign(self, node): buf = self.visit(node.target) @@ -376,7 +376,7 @@ def visit_Subscript(self, node): args = [args] arr = self.visit(node.value) - if isinstance(arr, (Array, list, tuple)): + if isinstance(arr, Array): for i in args: if isinstance(i, numbers.Integral): arr = arr[i] diff --git a/python/tvm/te/hybrid/utils.py b/python/tvm/te/hybrid/utils.py index a515938fa524..f653b3e83d8b 100644 --- a/python/tvm/te/hybrid/utils.py +++ b/python/tvm/te/hybrid/utils.py @@ -33,9 +33,9 @@ # pylint: disable=invalid-name -np_arg_types = (numpy.ndarray, *numeric_types) -tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr, *numeric_types, list, tuple, str) -halide_imm_types = (_expr.IntImm, _expr.FloatImm, *numeric_types) +np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) +tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) +halide_imm_types = (_expr.IntImm, _expr.FloatImm) def _internal_assert(cond, err): @@ -91,13 +91,19 @@ def replace(op): def _is_tvm_arg_types(args): """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. If neither is true, raise a value error.""" - if all(isinstance(elem, tvm_arg_types) for elem in args): + if isinstance(args[0], tvm_arg_types): + for elem in args[1:]: + _internal_assert( + isinstance(elem, tvm_arg_types), + f"Expecting a Var, Tensor or ConstExpr instance but {type(elem)} get!", + ) return True - elif all(isinstance(elem, np_arg_types) for elem in args): - return False - else: - raise ValueError( - f"Expected arguments to be entirely TVM types, " - f"or entirely numpy types, " - f"but received {[type(elem) for elem in args]}" + + _internal_assert( + isinstance(args[0], np_arg_types), f"Expect a numpy type but {type(args[0])} get!" + ) + for elem in args[1:]: + _internal_assert( + isinstance(elem, np_arg_types), f"Expect a numpy type but {type(elem)} get!" ) + return False diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 64a282dcf755..dc2c67849925 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -53,6 +53,7 @@ def placeholder(shape, dtype=None, name="placeholder"): tensor: Tensor The created tensor """ + shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape dtype = "float32" if dtype is None else dtype return _ffi_api.Placeholder(shape, dtype, name) diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 930667242e29..d435e821acf3 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -64,7 +64,16 @@ def __call__(self, *indices): f"Need to provide {ndim} index in tensor but {len(indices)} was provided" ) indices = convert_to_object(indices) - return _expr.ProducerLoad(self, indices) + args = [] + for x in indices: + if isinstance(x, _expr.PrimExpr): + args.append(x) + elif isinstance(x, _expr.IterVar): + args.append(x.var) + else: + raise ValueError("The indices must be expression") + + return _expr.ProducerLoad(self, args) def __getitem__(self, indices): return TensorSlice(self, indices) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 0c8048d24d8b..bcfbe6575d52 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -21,7 +21,6 @@ from .buffer import Buffer, decl_buffer, DataProducer from .data_layout import Layout, BijectiveLayout, bijective_layout, layout -from .expr import convert from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 37976394f831..c78bb9e7ecd0 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -41,10 +41,6 @@ from .buffer import Buffer, DataProducer -def convert(expr) -> PrimExpr: - return _ffi_api.convert(expr) - - def div_ambiguity_error() -> RuntimeError: return RuntimeError( "TVM supports multiple types of integer divisions, " diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 777d46ec7b0d..50de995a9145 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,7 +17,7 @@ """Developer API of IR node builder make function.""" import tvm from tvm._ffi.base import string_types -from tvm.runtime import ObjectGeneric, const +from tvm.runtime import ObjectGeneric, convert, const from tvm.ir import container as _container from . import stmt as _stmt @@ -107,9 +107,7 @@ def __getitem__(self, index): def __setitem__(self, index, value): index = self._normalize_index(index) - if isinstance(value, (int, bool, float)): - value = tvm.tir.const(value) - + value = convert(value) value_element = value.dtype.split("x", maxsplit=1)[0] content_element = self._content_type.split("x", maxsplit=1)[0] if value_element != content_element: diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 8d9647b60049..0bc299e403c5 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -19,14 +19,13 @@ from typing import Any, Optional, Union import tvm._ffi -from tvm import tir from tvm.ir import Array, Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import const +from tvm.runtime import const, convert from . import _ffi_api from .buffer import Buffer -from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var +from .expr import Call, CommReducer, IntImm, PrimExprWithOp, StringImm, Var def _pack_buffer(buf, span=None): @@ -182,7 +181,7 @@ def call_intrin(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, func_name, args, span) + return Call(dtype, func_name, convert(args), span) def call_pure_extern(dtype, func_name, *args, span=None): @@ -207,7 +206,9 @@ def call_pure_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span) + return Call( + dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span + ) def call_extern(dtype, func_name, *args, span=None): @@ -232,7 +233,9 @@ def call_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, Op.get("tir.call_extern"), [func_name, *args], span=span) + return Call( + dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span + ) def call_llvm_intrin(dtype, name, *args, span=None): @@ -1829,10 +1832,13 @@ def dp4a(vec1, vec2, acc=0): call : PrimExpr The call expression. """ + vec1 = convert(vec1) + vec2 = convert(vec2) + acc = convert(acc) return call_intrin("int32", "tir.dp4a", vec1, vec2, acc) -def ret(val, span=None): +def ret(val): """Create a tir return expression Parameters @@ -1840,16 +1846,14 @@ def ret(val, span=None): val : Expr The returned tir expression, whose data type is int, float or void pointer. - span : Optional[Span] - The location of this operator in the source code. - Returns ------- ret : PrimExpr The return expression """ - return _ffi_api.ret(val, span) + val = convert(val) + return call_intrin(val.dtype, "tir.ret", val) def any(*args, span=None): @@ -2034,7 +2038,7 @@ def exp(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.exp", x) @@ -2051,7 +2055,7 @@ def exp2(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.exp2", x) @@ -2068,7 +2072,7 @@ def exp10(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.exp10", x) @@ -2085,7 +2089,7 @@ def erf(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.erf", x) @@ -2102,7 +2106,7 @@ def tanh(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.tanh", x) @@ -2119,7 +2123,7 @@ def sigmoid(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.sigmoid", x) @@ -2136,7 +2140,7 @@ def log(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.log", x) @@ -2153,7 +2157,7 @@ def log2(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.log2", x) @@ -2170,7 +2174,7 @@ def log10(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.log10", x) @@ -2187,7 +2191,7 @@ def log1p(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.log1p", x) @@ -2204,7 +2208,7 @@ def tan(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.tan", x) @@ -2221,7 +2225,7 @@ def cos(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.cos", x) @@ -2238,7 +2242,7 @@ def cosh(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.cosh", x) @@ -2255,7 +2259,7 @@ def acos(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.acos", x) @@ -2272,7 +2276,7 @@ def acosh(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.acosh", x) @@ -2289,7 +2293,7 @@ def sin(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.sin", x) @@ -2306,7 +2310,7 @@ def sinh(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.sinh", x) @@ -2323,7 +2327,7 @@ def asin(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.asin", x) @@ -2340,7 +2344,7 @@ def asinh(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.asinh", x) @@ -2357,7 +2361,7 @@ def atan(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.atan", x) @@ -2374,7 +2378,7 @@ def atanh(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.atanh", x) @@ -2394,8 +2398,8 @@ def atan2(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.atan2", x1, x2) @@ -2412,7 +2416,7 @@ def sqrt(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.sqrt", x) @@ -2429,7 +2433,7 @@ def rsqrt(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.rsqrt", x) @@ -2675,8 +2679,8 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.nextafter", x1, x2) # type: ignore @@ -2696,8 +2700,8 @@ def hypot(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.hypot", x1, x2) # type: ignore @@ -2717,8 +2721,8 @@ def copysign(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.copysign", x1, x2) # type: ignore @@ -2738,8 +2742,8 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore @@ -2858,7 +2862,7 @@ def power(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(x, y, span) # type: ignore + return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore def pow(x, y, span=None): @@ -2880,7 +2884,7 @@ def pow(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(x, y, span) # type: ignore + return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore def popcount(x): @@ -2896,7 +2900,7 @@ def popcount(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = convert(x) return call_intrin(x.dtype, "tir.popcount", x) @@ -3028,8 +3032,8 @@ def fmod(x, y): z : PrimExpr The result. """ - x = tir.convert(x) - y = tir.convert(y) + x = convert(x) + y = convert(y) return call_intrin(x.dtype, "tir.fmod", x, y) @@ -3063,7 +3067,7 @@ def if_then_else(cond, t, f, span=None): Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions. """ - return _ffi_api._OpIfThenElse(cond, t, f, span) # type: ignore + return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f), span) # type: ignore def div(a, b, span=None): @@ -3310,23 +3314,34 @@ def _reduce_directly(*args): def _make_reduce(expr, axis, where=None, init=None): code = fcombine.__code__ assert fcombine.__code__.co_argcount == 2 - expr = tir.convert(expr) + expr = convert(expr) if init is not None: - init = tir.convert(init) + init = convert(init) if isinstance(expr, Array): size = len(expr) - lhs = [] - rhs = [] + larr = [] + rarr = [] dtypes = [] for i in range(size): dtype = expr[i].dtype dtypes.append(dtype) lname = code.co_varnames[0] + "_" + str(i) - lhs.append(Var(lname, dtype)) + larr.append(Var(lname, dtype)) rname = code.co_varnames[1] + "_" + str(i) - rhs.append(Var(rname, dtype)) - if init is None: - init = [] + rarr.append(Var(rname, dtype)) + if init is not None: + init = convert(init) + assert isinstance(init, Array) + assert len(init) == size + for init_i in range(size): + init_i = convert(init_i) + assert isinstance( + init_i, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm) + ) + else: + init = convert([]) + lhs = convert(larr) + rhs = convert(rarr) result = fcombine(lhs, rhs) id_elem = fidentity(*dtypes) else: @@ -3337,18 +3352,22 @@ def _make_reduce(expr, axis, where=None, init=None): rvar = Var(code.co_varnames[1], dtype) result = [fcombine(lvar, rvar)] id_elem = [fidentity(dtype)] - lhs = [lvar] - rhs = [rvar] - expr = [expr] + lhs = convert([lvar]) + rhs = convert([rvar]) + expr = convert([expr]) if init is not None: - init = [init] + assert isinstance(init, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm)) + init = convert([init]) + result = convert(result) + id_elem = convert(id_elem) combiner = CommReducer(lhs, rhs, result, id_elem) - if not isinstance(axis, (list, tuple, tvm.ir.Array)): - axis = [axis] + axis = convert(axis if isinstance(axis, (list, tuple)) else [axis]) if where is None: - where = tir.convert(True) + where = convert(True) if init is None: - outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, []) for i in range(size)) + outputs = tuple( + tvm.tir.Reduce(combiner, expr, axis, where, i, convert([])) for i in range(size) + ) else: outputs = tuple( tvm.tir.Reduce(combiner, expr, axis, where, i, init) for i in range(size) diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index 85377560f1fc..cb8d5ce9973e 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -39,20 +39,17 @@ def _json_from_tvm(obj): if obj is None: return None - elif isinstance(obj, (bool, int, float, str)): - return obj - elif isinstance(obj, Array): + if isinstance(obj, Array): return [_json_from_tvm(i) for i in obj] - elif isinstance(obj, Map): + if isinstance(obj, Map): return {_json_from_tvm(k): _json_from_tvm(v) for k, v in obj.items()} - elif isinstance(obj, String): + if isinstance(obj, String): return str(obj) - elif isinstance(obj, (IntImm, FloatImm)): + if isinstance(obj, (IntImm, FloatImm)): return obj.value - elif isinstance(obj, IndexMap): + if isinstance(obj, IndexMap): return save_json(obj) - else: - raise TypeError("Not supported type: " + str(type(obj))) + raise TypeError("Not supported type: " + str(type(obj))) @_register_object("tir.Trace") diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index cc1a28b9dee0..bf6a9c75516f 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -468,7 +468,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): C = out.op.input_tensors[0] A = C.op.input_tensors[0] in_type = A.dtype - use_scalable_vectors = bool(out.op.attrs["use_scalable_vectors"]) + use_scalable_vectors = out.op.attrs["use_scalable_vectors"].value tile_M, tile_K = arm_utils.get_tiling_A(False, in_type) tile_N, _ = arm_utils.get_tiling_B_transformed(False, in_type, use_scalable_vectors) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 0a7acfa50444..83b000a4b9bb 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -295,11 +295,15 @@ def batch_matmul_int8( # pad for _dp4a vectorize pad_x = te.compute( (XB, M, nK), - lambda b, i, j: tvm.te.if_then_else(j >= XK, tvm.tir.const(0, x.dtype), x[b, i, j]), + lambda b, i, j: tvm.te.if_then_else( + j >= XK, tvm.runtime.convert(0).astype(x.dtype), x[b, i, j] + ), ) pad_y = te.compute( (YB, N, nK), - lambda b, i, j: tvm.te.if_then_else(j >= YK, tvm.tir.const(0, y.dtype), y[b, i, j]), + lambda b, i, j: tvm.te.if_then_else( + j >= YK, tvm.runtime.convert(0).astype(y.dtype), y[b, i, j] + ), ) out = te.compute( diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index b98d9c102baa..8d59c2a035a9 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -48,7 +48,7 @@ pub struct ModuleNode { crate::external! { #[name("runtime.RuntimeEnabled")] - fn runtime_enabled(target: CString) -> bool; + fn runtime_enabled(target: CString) -> i32; #[name("runtime.ModuleLoadFromFile")] fn load_from_file(file_name: CString, format: CString) -> Module; @@ -121,7 +121,8 @@ impl Module { /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let target = CString::new(target).unwrap(); - runtime_enabled(target).unwrap() + let enabled = runtime_enabled(target).unwrap(); + enabled != 0 } /// Returns the underlying module handle. diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 2c1f7db6adb0..a74cbe318e2d 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -73,7 +73,6 @@ macro_rules! TVMPODValue { Int(i64), UInt(i64), Float(f64), - Bool(bool), Null, DataType(DLDataType), String(*mut c_char), @@ -96,7 +95,6 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool), TVMArgTypeCode_kTVMNullptr => Null, TVMArgTypeCode_kTVMDataType => DataType($value.v_type), TVMArgTypeCode_kDLDevice => Device($value.v_device), @@ -119,7 +117,6 @@ macro_rules! TVMPODValue { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool), Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), @@ -266,7 +263,6 @@ macro_rules! impl_pod_value { impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); impl_pod_value!(Float, f64, [f32, f64]); -impl_pod_value!(Bool, bool, [bool]); impl_pod_value!(DataType, DLDataType, [DLDataType]); impl_pod_value!(Device, DLDevice, [DLDevice]); @@ -384,6 +380,37 @@ impl TryFrom for std::ffi::CString { } } +// Implementations for bool. + +impl<'a> From<&bool> for ArgValue<'a> { + fn from(s: &bool) -> Self { + (*s as i64).into() + } +} + +impl From for RetValue { + fn from(s: bool) -> Self { + (s as i64).into() + } +} + +impl TryFrom for bool { + type Error = ValueDowncastError; + + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> bool, + |RetValue::Int(val)| { !(val == 0) }) + } +} + +impl<'a> TryFrom> for bool { + type Error = ValueDowncastError; + + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) }) + } +} + impl From<()> for RetValue { fn from(_: ()) -> Self { RetValue::Null diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 82e439cddbc2..e03d4302c89f 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -554,19 +554,9 @@ class FlopEstimator : public ExprFunctor { if (auto pop = op.as()) { if (pop->attrs.count("FLOP")) { // Use user-provided FLOP - ObjectRef annotation = pop->attrs["FLOP"]; - auto value = [&]() -> int64_t { - if (auto runtime_int = annotation.as()) { - return runtime_int->value; - } else if (auto int_imm = annotation.as()) { - return int_imm->value; - } else { - LOG(FATAL) << "FLOP annotation must be an integer, " - << "but was an object of type " << annotation->GetTypeKey(); - } - }(); - - ret += value; + auto pint = pop->attrs["FLOP"].as(); + ICHECK(pint != nullptr); + ret += pint->value; } else { // Estimate by parsing the compute body double num_element = AxisLengthProd(pop->axis); diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 0bf6da255d2a..862e593c9dd3 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -482,8 +482,7 @@ std::vector> RuleCustomSketch::Apply(const SketchPolicyNod std::vector> ret; for (const auto& item : apply_ret) { CHECK_EQ(item.size(), 2); - auto next = item[1].as(); - ICHECK(next); + auto next = item[1].as(); ret.emplace_back(Downcast(item[0]), next->value); } return ret; diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index cc6b0ab23756..76fb77dd9527 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -101,7 +101,7 @@ inline int OperationToStage(const te::Operation& op, const State& state) { /*! \brief Get an integer from a tvm str Map. */ inline int GetIntParam(const Map& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pint = attr_dict[key].as(); + auto pint = attr_dict[key].as(); ICHECK(pint != nullptr); return pint->value; } @@ -109,7 +109,7 @@ inline int GetIntParam(const Map& attr_dict, const std::strin /*! \brief Get a double from a tvm str Map. */ inline double GetDoubleParam(const Map& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pdouble = attr_dict[key].as(); + auto pdouble = attr_dict[key].as(); ICHECK(pdouble != nullptr); return pdouble->value; } @@ -120,12 +120,10 @@ inline std::string GetStringParam(const Map& attr_dict, const const auto& target = attr_dict[key]; if (auto pstr = target.as()) { return pstr->value; - } else if (auto pstr = target.as()) { - return pstr->data; - } else { - LOG(FATAL) << "Could not convert object " << target << " of type " << target->GetTypeKey() - << " to string"; } + auto pstr = target.as(); + ICHECK(pstr != nullptr); + return pstr->data; } /*! \brief Get a iterator name set from a tvm str Map. */ diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 708fb56c9851..289c1b79fd66 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -100,17 +100,8 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { const ObjectRef& value = doc->value; if (!value.defined()) { output_ << "\"\""; - } else if (const auto* runtime_int = value.as()) { - output_ << runtime_int->value; } else if (const auto* int_imm = value.as()) { output_ << int_imm->value; - } else if (const auto* runtime_float = value.as()) { - output_.precision(config_.float_precision); - if (std::isinf(runtime_float->value) || std::isnan(runtime_float->value)) { - output_ << '"' << runtime_float->value << '"'; - } else { - output_ << runtime_float->value; - } } else if (const auto* float_imm = value.as()) { output_.precision(config_.float_precision); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index 99be910bd70a..7e96c657a711 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -33,10 +33,6 @@ namespace msc { LiteralDoc PrototxtPrinter::ToLiteralDoc(const ObjectRef& obj) { if (obj.as()) { return LiteralDoc::Str(Downcast(obj), NullOpt); - } else if (auto ptr = obj.as()) { - return LiteralDoc::Int(ptr->value, NullOpt); - } else if (auto ptr = obj.as()) { - return LiteralDoc::Float(ptr->value, NullOpt); } else if (obj.as()) { return LiteralDoc::Int(Downcast(obj)->value, NullOpt); } else if (obj.as()) { diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index 5fcbe924ae1c..f58f95ae53b0 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -263,10 +263,6 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { obj_string = ""; } else if (obj.as()) { obj_string = Downcast(obj); - } else if (const auto* n = obj.as()) { - obj_string = std::to_string(n->value); - } else if (const auto* n = obj.as()) { - obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 1e576bc91002..105ac063e0ea 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -171,10 +171,9 @@ Array CreatePassList(bool disable_loop_partition) { // phase passes is of the form // [[phase_number, pass], [phase_number, pass]... ] for (Array phase_pass : add_lower_pass) { - auto phase_num = phase_pass[0].as(); + const IntImmNode* phase_num = phase_pass[0].as(); ICHECK(phase_num) - << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer, " - << "but instead received " << phase_pass[0] << " with type " << phase_pass[0]->GetTypeKey(); + << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; int phase_num_val = phase_num->value; CHECK_GE(phase_num_val, 0); diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 08e7ffc5bf59..f197ac4416fa 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -31,91 +31,6 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } -namespace { - -/* \brief Normalize attributes from runtime types to Relax IR types - * - * While conversion from `tvm::runtime` types to compile-time IR - * types usually occurs as part of FFI conversions, the attributes - * are not converted, as they are stored in a `Map`. While this is required to allow attribute values to - * contain `ObjectRef` instances that are not IR expressions, the - * conversion should still be applied when possible. - * - * \param obj The IR attribute value to be normalized - * - * \return The normalized attribute value - */ -ObjectRef NormalizeAttr(ObjectRef obj) { - if (auto dict_attrs = obj.as()) { - auto new_dict = Downcast>(NormalizeAttr(dict_attrs->dict)); - if (new_dict.same_as(dict_attrs->dict)) { - return obj; - } else { - return DictAttrs(new_dict); - } - } else if (auto runtime_bool = obj.as()) { - return Bool(runtime_bool->value); - } else if (auto runtime_int = obj.as()) { - return Integer(runtime_int->value); - } else if (auto opt_array = obj.as>()) { - return opt_array.value().Map([](const ObjectRef& inner) { return NormalizeAttr(inner); }); - } else if (auto opt_map = obj.as>()) { - auto map = opt_map.value(); - - Map updates; - for (const auto& [key, inner] : map) { - auto new_inner = NormalizeAttr(inner); - if (!new_inner.same_as(inner)) { - updates.Set(key, new_inner); - } - } - for (const auto& [key, new_inner] : updates) { - map.Set(key, new_inner); - } - - return map; - - } else { - return obj; - } -} -} // namespace - -DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { - if (new_attrs.empty()) { - return attrs; - } - - auto* write_ptr = attrs.CopyOnWrite(); - Map attr_dict = std::move(write_ptr->dict); - - for (const auto& [key, value] : new_attrs) { - attr_dict.Set(key, NormalizeAttr(value)); - } - - write_ptr->dict = std::move(attr_dict); - return attrs; -} - -DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value) { - auto* write_ptr = attrs.CopyOnWrite(); - Map attr_dict = std::move(write_ptr->dict); - attr_dict.Set(key, NormalizeAttr(value)); - - write_ptr->dict = std::move(attr_dict); - return attrs; -} - -DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { - auto* write_ptr = attrs.CopyOnWrite(); - Map attr_dict = std::move(write_ptr->dict); - attr_dict.erase(key); - - write_ptr->dict = std::move(attr_dict); - return attrs; -} - void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; @@ -128,15 +43,11 @@ void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_un dict.Set(key, val.operator PrimExpr()); } } - - dict = Downcast>(NormalizeAttr(dict)); } Array DictAttrsNode::ListFieldInfo() const { return {}; } DictAttrs::DictAttrs(Map dict) { - dict = Downcast>(NormalizeAttr(dict)); - ObjectPtr n = make_object(); n->dict = std::move(dict); data_ = std::move(n); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index ded046eafc5d..596805f74b24 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -47,12 +47,6 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { if (auto opt = ref.as()) { return tir::StringImm(opt.value()); } - if (auto opt = ref.as()) { - return Bool(opt.value()); - } - if (auto opt = ref.as()) { - return Integer(opt.value()); - } if (const auto* buffer_region = ref.as()) { Array indices; indices.reserve(buffer_region->region.size()); @@ -161,14 +155,9 @@ Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { TVM_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); -TVM_REGISTER_GLOBAL("ir.Range") - .set_body_typed([](PrimExpr begin, Optional end, Span span) -> Range { - if (end.defined()) { - return Range(begin, end.value(), span); - } else { - return Range(IntImm(begin->dtype, 0), begin, span); - } - }); +TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Range(args[0], args[1], args[2]); +}); TVM_REGISTER_NODE_TYPE(RangeNode); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index f0b879acbc03..dc67822411c5 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -107,42 +107,43 @@ bool PassContext::PassEnabled(const PassInfo& info) const { class PassConfigManager { public: - void Register(std::string key, uint32_t value_type_index, - std::function legalization) { + void Register(std::string key, uint32_t value_type_index) { ICHECK_EQ(key2vtype_.count(key), 0U); ValueTypeInfo info; info.type_index = value_type_index; info.type_key = runtime::Object::TypeIndex2Key(value_type_index); - info.legalization = legalization; key2vtype_[key] = info; } // Trying to validate and legalize a config. void Legalize(Map* config) { std::vector> update; - for (auto [key, obj] : *config) { - auto it = key2vtype_.find(key); + auto* reflection = ReflectionVTable::Global(); + + for (auto kv : *config) { + auto it = key2vtype_.find(kv.first); if (it == key2vtype_.end()) { std::ostringstream os; - os << "AttributeError: Invalid config option \'" << key << "\' candidates are:"; + os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:"; int counter = 0; - for (const auto& [key, obj] : key2vtype_) { + for (const auto& kv : key2vtype_) { os << ' '; if (counter++ != 0) os << ','; - os << key; + os << kv.first; } LOG(FATAL) << os.str(); } const auto& info = it->second; - - ICHECK(obj.defined()) << "AttributeError: " << key << " is None"; - - ICHECK(info.legalization) << "AttributeError: " - << "Config option \'" << key - << "\' was defined without a legalization function."; - auto legalized = info.legalization(obj); - if (!legalized.same_as(obj)) { - update.emplace_back(key, legalized); + ICHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None"; + if (kv.second->IsInstance::ContainerType>()) { + ObjectRef converted = + reflection->CreateObject(info.type_key, Downcast>(kv.second)); + update.emplace_back(kv.first, converted); + } else { + if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { + LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " + << info.type_key << " but get " << kv.second->GetTypeKey(); + } } } for (auto&& kv : update) { @@ -169,15 +170,13 @@ class PassConfigManager { struct ValueTypeInfo { std::string type_key; uint32_t type_index; - std::function legalization; }; std::unordered_map key2vtype_; }; -void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index, - std::function legalization) { - PassConfigManager::Global()->Register(key, value_type_index, legalization); +void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { + PassConfigManager::Global()->Register(key, value_type_index); } Map> PassContext::ListConfigs() { diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index ce025540e496..416753871244 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -39,14 +39,8 @@ void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { } else { os << int_imm->value; } - } else if (const auto* runtime_bool = json_obj.as()) { - os << (runtime_bool->value ? "true" : "false"); - } else if (const auto* runtime_int = json_obj.as()) { - os << runtime_int->value; } else if (const auto* float_imm = json_obj.as()) { os << std::setprecision(20) << float_imm->value; - } else if (const auto* runtime_float = json_obj.as()) { - os << std::setprecision(20) << runtime_float->value; } else if (const auto* str = json_obj.as()) { os << '"' << support::StrEscape(str->data, str->size) << '"'; } else if (const auto* array = json_obj.as()) { @@ -171,7 +165,7 @@ class JSONTokenizer { std::string to_parse(st, cur_); if (!is_float) { try { - *token = Token{TokenType::kInteger, runtime::Int(std::stoll(to_parse))}; + *token = Token{TokenType::kInteger, IntImm(DataType::Int(64), std::stoll(to_parse))}; } catch (const std::invalid_argument& e) { LOG(WARNING) << "ValueError: Invalid argument to std::stoll: " << to_parse << ". Details: " << e.what() << ". Switching to std::stod now."; @@ -184,7 +178,7 @@ class JSONTokenizer { } if (is_float) { try { - *token = Token{TokenType::kFloat, runtime::Float(std::stod(to_parse))}; + *token = Token{TokenType::kFloat, FloatImm(DataType::Float(64), std::stod(to_parse))}; } catch (const std::invalid_argument& e) { LOG(INFO) << "ValueError: Invalid argument to std::stod: " << to_parse << ". Details: " << e.what(); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 63af4a684567..53f680f0a666 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -192,9 +192,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, try { const ArrayNode* arr = json_obj.as(); ICHECK_EQ(arr->size(), 2); - int64_t workload_index = Downcast(arr->at(0)); - ICHECK(workload_index >= 0 && static_cast(workload_index) < workloads.size()); - workload = workloads[workload_index]; + workload = workloads[Downcast(arr->at(0)).IntValue()]; records[task_id] = TuningRecord::FromJSON(arr->at(1), workload); } catch (std::runtime_error& e) { LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 5b3e6d251d56..f5d89a85092b 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -137,7 +137,7 @@ std::vector MutateThreadBindingNode::FindCan ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; - int decision = Downcast(trace->decisions[GetRef(sample_inst)]); + int decision = Downcast(trace->decisions[GetRef(sample_inst)])->value; std::vector probs = support::AsVector(Downcast>(sample_inst->attrs[1])); diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index a78b829e34ab..ea4e81c16f0c 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -129,13 +129,13 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].get())) { ICHECK_EQ(inst->attrs.size(), 2); - std::vector probs = support::AsVector( - Downcast>(inst->attrs[1])); + std::vector probs = + support::AsVector(Downcast>(inst->attrs[1])); if (probs.size() == 1) { // Skip mutating the sampling instructions who have only single candidate. continue; } - const auto* d = TVM_TYPE_AS(decision, runtime::Int::ContainerType); + const auto* d = TVM_TYPE_AS(decision, IntImmNode); instructions.push_back(inst); decisions.push_back(d->value); } diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 36dc57d80e66..7bbf00343af3 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -114,9 +114,9 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, ICHECK_EQ(sample_inst->attrs.size(), 2); candidate->inst = GetRef(sample_inst); candidate->decision = - Downcast(trace->decisions[GetRef(sample_inst)])->value; - candidate->probs = support::AsVector( - Downcast>(sample_inst->attrs[1])); + Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->probs = + support::AsVector(Downcast>(sample_inst->attrs[1])); return true; } diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index 110cae96cb53..b651b1f401cb 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -34,11 +34,11 @@ using namespace tvm::tir; std::function MakeFactorSampler(Schedule sch, Array thread_extents) { return [sch = std::move(sch), thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV { - Array extents; + Array extents; extents.reserve(thread_extents.size()); for (const Integer extent : thread_extents) { if (extent->value <= max_extent) { - extents.push_back(runtime::Int(extent->value)); + extents.push_back(extent); } } int n = extents.size(); @@ -48,7 +48,7 @@ std::function MakeFactorSampler(Schedule sch, Array th if (n == 1) { return Integer(extents[0]); } - Array probs(n, runtime::Float(1.0 / n)); + Array probs(n, FloatImm(DataType::Float(64), 1.0 / n)); return sch->SampleCategorical(extents, probs); }; } diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 4a304cefa6bb..e8d821636fd3 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -73,7 +73,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 3. Try block fusion. int n_candidate = static_cast(thread_extents.size()); - Array probs(n_candidate, 1.0 / n_candidate); + Array probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate)); tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { ICHECK(target_block.defined()); @@ -267,7 +267,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { /*! \brief The number of threads per warp */ int warp_size; /*! \brief Candidates of thread axis extent (values are required to be positive). */ - Array thread_extents; + Array thread_extents; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("max_threads_per_block", &max_threads_per_block); @@ -279,8 +279,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); }; -ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { - for (const auto& extent : thread_extents) { +ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { + for (const Integer& extent : thread_extents) { CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive"; } ObjectPtr n = make_object(); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 2979e4229bdd..bcaf4343e256 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -383,8 +383,9 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, if (!valid_vector_lens.empty()) { int n = valid_vector_lens.size(); double prob = 1.0 / n; - tir::ExprRV vector_load_len = (*sch)->SampleCategorical( - support::AsArray(valid_vector_lens), Array(n, prob)); + tir::ExprRV vector_load_len = + (*sch)->SampleCategorical(support::AsArray(valid_vector_lens), + Array(n, FloatImm(DataType::Float(64), prob))); (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); } } diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 8ea2c2d1c6c3..045aa85b73ad 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -68,7 +68,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) { int n = unroll_max_steps.size(); double prob = 1.0 / n; - Array probs(n, runtime::Float(prob)); + Array probs(n, FloatImm(DataType::Float(64), prob)); PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs); if (unroll_explicit) { sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step); @@ -102,7 +102,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { * \brief The options of the maximum number of unroll steps to be done. * Use an empty array to disable unroll. */ - Array unroll_max_steps; + Array unroll_max_steps; /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */ bool unroll_explicit; /*! \brief The number of maximum available jobs in CPU. */ @@ -122,7 +122,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, int max_vectorize_extent, - Array unroll_max_steps, + Array unroll_max_steps, bool unroll_explicit) { ObjectPtr n = make_object(); n->max_jobs_per_core = max_jobs_per_core; diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 83f5d073cb32..3be264332461 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -79,7 +79,7 @@ Array ScheduleRule::DefaultLLVM() { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; @@ -126,7 +126,7 @@ Array ScheduleRule::DefaultX86(const String& type) { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; @@ -158,11 +158,11 @@ Array ScheduleRule::DefaultCUDA() { /*require_ordered=*/false, /*disallow_op=*/Array{}), ScheduleRule::CrossThreadReduction( - /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), + /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/-1, /*max_vectorize_extent=*/-1, - /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, + /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, /*unroll_explicit=*/true), ScheduleRule::AutoBind( /*max_threadblocks=*/256, @@ -297,7 +297,7 @@ Array ScheduleRule::DefaultHexagon() { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/128, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), }; } @@ -410,7 +410,7 @@ Array ScheduleRule::DefaultARM(const String& type) { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/8, /*max_vectorize_extent=*/32, - /*unroll_max_steps=*/Array{0, 8, 32, 256}, + /*unroll_max_steps=*/Array{0, 8, 32, 256}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation()); } diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 28c45ea7455d..ceb0356cbcfe 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -424,22 +424,13 @@ inline Array AsFloatArray(const ObjectRef& obj) { Array results; results.reserve(arr->size()); for (const ObjectRef& elem : *arr) { - auto float_value = [&]() -> double { - if (const auto* int_imm = elem.as()) { - return int_imm->value; - } else if (const auto* runtime_int = elem.as()) { - return runtime_int->value; - } else if (const auto* float_imm = elem.as()) { - return float_imm->value; - } else if (const auto* runtime_float = elem.as()) { - return runtime_float->value; - } else { - LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " - << elem->GetTypeKey(); - } - }(); - - results.push_back(FloatImm(DataType::Float(32), float_value)); + if (const auto* int_imm = elem.as()) { + results.push_back(FloatImm(DataType::Float(32), int_imm->value)); + } else if (const auto* float_imm = elem.as()) { + results.push_back(FloatImm(DataType::Float(32), float_imm->value)); + } else { + LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << elem->GetTypeKey(); + } } return results; } @@ -455,16 +446,11 @@ inline Array AsIntArray(const ObjectRef& obj) { Array results; results.reserve(arr->size()); for (const ObjectRef& elem : *arr) { - auto int_value = [&]() -> int64_t { - if (const auto* int_imm = elem.as()) { - return int_imm->value; - } else if (const auto* runtime_int = elem.as()) { - return runtime_int->value; - } else { - LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); - } - }(); - results.push_back(Integer(int_value)); + if (const auto* int_imm = elem.as()) { + results.push_back(Integer(int_imm->value)); + } else { + LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); + } } return results; } diff --git a/src/node/boxed_primitive.cc b/src/node/boxed_primitive.cc deleted file mode 100644 index 86596fb5ce29..000000000000 --- a/src/node/boxed_primitive.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file node/boxed_primitive.cc - * - * \brief Reflection utilities for runtime-supported classes - * - * The fundamental support for boxing and unboxing of primitives - * during FFI calls is implemented in runtime/boxed_primitive.cc. In - * addition, boxed primitives may be registered with compile-time - * utilities (e.g. reflection, JSON import/export) that can provide - * additional functionality and improved debugging ability. However, - * neither these compile-time utilities nor any registration of - * `Box` into the compile-time utilities should be included as - * part of `libtvm_runtime.so`. - * - * This file contains the registration of the `libtvm_runtime.so` - * class `Box` for utilities that are contained in `libtvm.so`. - */ -#include -#include -#include -#include - -namespace tvm { -namespace runtime_ext { - -using runtime::Box; -using runtime::BoxNode; - -/* \brief Compile-time extension trait for runtime types - * - * Extends the use of boxed primitive during TVM's compilation step. - * - * Most TVM classes define these functions as part of the class - * definition. However, the boxed primitives must be usable at - * runtime, and so the class definition may only refer to types that - * are present in `libtvm_runtime.so`. - */ -template -struct BoxNodeCompileTimeTraits { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - - static void SHashReduce(const BoxNode* node, SHashReducer hash_reduce) { - hash_reduce(node->value); - } - - static bool SEqualReduce(const BoxNode* lhs, const BoxNode* rhs, - SEqualReducer equal) { - return equal(lhs->value, rhs->value); - } -}; - -TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) - .set_creator([](const std::string& blob) -> ObjectPtr { - int64_t value = std::atoll(blob.c_str()); - return make_object>(value); - }) - .set_repr_bytes([](const Object* n) -> std::string { - int64_t value = GetRef(n).as>().value()->value; - std::stringstream ss; - ss << value; - return ss.str(); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { - auto box = Downcast>(node); - p->stream << box->GetTypeKey() << "(" << box->value << ")"; - }); - -TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) - .set_creator([](const std::string& blob) -> ObjectPtr { - if (blob == "true") { - return make_object>(true); - } else if (blob == "false") { - return make_object>(false); - } else { - LOG(FATAL) << "Invalid string '" << blob << "' for boolean"; - } - }) - .set_repr_bytes([](const Object* n) -> std::string { - bool value = GetRef(n).as>().value()->value; - if (value) { - return "true"; - } else { - return "false"; - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { - auto box = Downcast>(node); - p->stream << box->GetTypeKey() << "(" << (box->value ? "true" : "false") << ")"; - }); - -TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) - .set_creator([](const std::string& blob) -> ObjectPtr { - double value = std::atof(blob.c_str()); - return make_object>(value); - }) - .set_repr_bytes([](const Object* n) -> std::string { - double value = GetRef(n).as>().value()->value; - std::stringstream ss; - ss << value; - return ss.str(); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { - auto box = Downcast>(node); - p->stream << box->GetTypeKey() << "(" << box->value << ")"; - }); - -} // namespace runtime_ext - -} // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index b8918b4ea48c..6e7d82ee4a59 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -57,7 +57,7 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->binding_names.push_back(Downcast(v)); } if (auto v = config_dict.Get("show_meta")) { - n->show_meta = Downcast(v)->value; + n->show_meta = Downcast(v)->value; } if (auto v = config_dict.Get("ir_prefix")) { n->ir_prefix = Downcast(v); @@ -81,16 +81,16 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->float_dtype = DataType(runtime::String2DLDataType(Downcast(v))); } if (auto v = config_dict.Get("verbose_expr")) { - n->verbose_expr = Downcast(v)->value; + n->verbose_expr = Downcast(v)->value; } if (auto v = config_dict.Get("indent_spaces")) { - n->indent_spaces = Downcast(v)->value; + n->indent_spaces = Downcast(v)->value; } if (auto v = config_dict.Get("print_line_numbers")) { - n->print_line_numbers = Downcast(v)->value; + n->print_line_numbers = Downcast(v)->value; } if (auto v = config_dict.Get("num_context_lines")) { - n->num_context_lines = Downcast(v)->value; + n->num_context_lines = Downcast(v)->value; } if (auto v = config_dict.Get("path_to_underline")) { n->path_to_underline = Downcast>>(v).value_or(Array()); @@ -107,13 +107,13 @@ PrinterConfig::PrinterConfig(Map config_dict) { Downcast>>(v).value_or(Map()); } if (auto v = config_dict.Get("syntax_sugar")) { - n->syntax_sugar = Downcast(v)->value; + n->syntax_sugar = Downcast(v)->value; } if (auto v = config_dict.Get("show_object_address")) { - n->show_object_address = Downcast(v)->value; + n->show_object_address = Downcast(v)->value; } if (auto v = config_dict.Get("show_all_struct_info")) { - n->show_all_struct_info = Downcast(v)->value; + n->show_all_struct_info = Downcast(v)->value; } // Checking prefixes if they are valid Python identifiers. diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 614669a412d0..379a75f6109b 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -65,22 +65,6 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, return fsequal_reduce_[tindex](self, other, equal); } -namespace { -ObjectPath GetAttrPath(const ObjectRef& obj, const void* attr_address, const ObjectPath& path) { - if (obj->IsInstance() || - obj->IsInstance() || - obj->IsInstance()) { - // Special case for containers that contain boxed primitives. The - // "value" attribute containing the boxed value should not be part - // of the reported mismatched path. - return path; - } else { - Optional attr_key = GetAttrKeyByAddress(obj.get(), attr_address); - return path->Attr(attr_key); - } -} -} // namespace - struct SEqualReducer::PathTracingData { ObjectPathPair current_paths; ObjectRef lhs_object; @@ -88,9 +72,10 @@ struct SEqualReducer::PathTracingData { Optional* first_mismatch; ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const { - ObjectPath lhs_attr_path = GetAttrPath(lhs_object, &lhs, current_paths->lhs_path); - ObjectPath rhs_attr_path = GetAttrPath(rhs_object, &rhs, current_paths->rhs_path); - return ObjectPathPair(lhs_attr_path, rhs_attr_path); + Optional lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs); + Optional rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs); + return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key), + current_paths->rhs_path->Attr(rhs_attr_key)); } }; @@ -113,12 +98,13 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { /* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch( const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) { if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) { - ObjectPath lhs_attr_path = - GetAttrPath(tracing_data->lhs_object, lhs_address, tracing_data->current_paths->lhs_path); - ObjectPath rhs_attr_path = - GetAttrPath(tracing_data->rhs_object, rhs_address, tracing_data->current_paths->rhs_path); - - *tracing_data->first_mismatch = ObjectPathPair(lhs_attr_path, rhs_attr_path); + Optional lhs_attr_key = + GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address); + Optional rhs_attr_key = + GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address); + *tracing_data->first_mismatch = + ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key), + tracing_data->current_paths->rhs_path->Attr(rhs_attr_key)); } } @@ -214,6 +200,7 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, } // Slow path: tracing object paths for better error reporting + ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths; if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) { diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 1c795594629e..334e6e5c9a62 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -45,7 +45,6 @@ using namespace relax; using namespace tvm::runtime; using namespace tvm::runtime::relax_vm; -namespace { // Helper function to get the function name of the registered packed function implementation of // relax operator. FCallPacked GetPackedFuncName(const Call& call) { @@ -58,7 +57,6 @@ FCallPacked GetPackedFuncName(const Call& call) { } return {}; } -} // namespace /*! * \brief A class to generate VM executable for Relax functions. diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index 5e6a1c3f8442..dd34bc63bb31 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -44,21 +44,6 @@ namespace relax_vm { using vm::VMFuncInfo; -namespace { -// Helper function to get the function name of the registered packed function implementation of -// relax operator. -FCallPacked GetPackedFuncName(const Call& call) { - static auto op_map = Op::GetAttrMap("FCallPacked"); - if (call->op.as()) { - Op op = Downcast(call->op); - if (op_map.count(op)) { - return op_map[op]; - } - } - return {}; -} -} // namespace - /*! * \brief A class to generate VMTIR for Relax functions. * @@ -247,14 +232,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); if (call->op.as()) { - // special case generate for the intrinsics whose attribute fields - // cannot be represented by args in the CallNode - FCallPacked name = GetPackedFuncName(call); - if (name.size()) { - // If the operator has a registered packed function implementation, emit call to that packed - // function. - EmitCallPacked(name, VisitArray(call->args), dst_reg); - } else if (call_node->op == call_builtin_with_ctx_op_) { + if (call_node->op == call_builtin_with_ctx_op_) { EmitCallBuiltinWithCtx(call, dst_reg); } else if (call_node->op == alloc_storage_op_) { EmitAllocStorage(call, dst_reg); @@ -282,8 +260,10 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); - cond_value = tir::Call(DataType::Bool(), tir::builtin::tvm_call_packed(), - {tir::StringImm("vm.builtin.read_if_cond"), cond_value}); + // turn ndarray cond value into scalar. + cond_value = tir::Cast(DataType::Bool(), + tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), + {tir::StringImm("vm.builtin.read_if_cond"), cond_value})); tir::Stmt true_branch = WithNewScope([&]() { PrimExpr true_value = this->VisitExpr(op->true_branch).value(); diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 7aca1470aee4..fd6fea6e703c 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -36,7 +36,7 @@ namespace relax { TVM_REGISTER_NODE_TYPE(InitAttrs); /* relax.full */ -Expr full(Variant> shape, Expr fill_value, DataType dtype) { +Expr full(ObjectRef shape, Expr fill_value, DataType dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { shape_in_expr = GetRef(expr); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 6e7c8255238a..989eaa12fdbf 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -39,7 +39,7 @@ namespace relax { * If dtype is not given, it will by default use the dtype of fill_value. * \return The result tensor. */ -Expr full(Variant> shape, Expr fill_value, DataType dtype); +Expr full(ObjectRef shape, Expr fill_value, DataType dtype); /*! * \brief Construct a tensor such that diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2b1c6eafb652..07c90756bf90 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -654,7 +654,7 @@ TVM_REGISTER_OP("relax.permute_dims") .set_attr("FPurity", Bool(true)); /* relax.reshape */ -Expr ConvertNewShapeToExpr(const Expr& data, const Variant>& shape) { +Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { const ArrayNode* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as()) { @@ -747,7 +747,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const Variant return ShapeExpr(array_ref); } -Expr reshape(Expr x, Variant> shape) { +Expr reshape(Expr x, ObjectRef shape) { Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); static const Op& op = Op::Get("relax.reshape"); return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); @@ -812,7 +812,7 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); -Expr split(Expr x, Variant> indices_or_sections, int axis) { +Expr split(Expr x, ObjectRef indices_or_sections, int axis) { ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 68622f1359e0..32aa10776894 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -90,7 +90,7 @@ Expr permute_dims(Expr x, Optional> axes); * It is required to be either an Array of PrimExpr, or a Shape in Relax * \return The reshaped result. */ -Expr reshape(Expr x, Variant> shape); +Expr reshape(Expr x, ObjectRef shape); /*! * \brief Split input tensor along axis by sections or indices. @@ -105,7 +105,7 @@ Expr reshape(Expr x, Variant> shape); * \param axis The axis over which to split. * \return The computed result. */ -Expr split(Expr x, Variant> indices_or_sections, int axis); +Expr split(Expr x, ObjectRef indices_or_sections, int axis); /*! * \brief Squeeze axes in the array. diff --git a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc index 345e2d0e60da..61b6c9ce897f 100644 --- a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc +++ b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc @@ -40,7 +40,7 @@ Target CreateTarget(const tvm::transform::PassContext& ctx) { String mcpu = cfg.value()->mcpu; Array mattr = {cfg.value()->mattr}; - runtime::Bool debug_last_error = cfg.value()->debug_last_error->value; + Bool debug_last_error = cfg.value()->debug_last_error; Target cmsis_nn_target(TargetJSON{ {"kind", String("cmsis-nn")}, diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index 00581a089a4a..10125bf814ad 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -37,7 +37,7 @@ using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc>("mattr") .add_attr_option("mcpu") - .add_attr_option("debug_last_error") + .add_attr_option("debug_last_error") .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); diff --git a/src/relay/backend/contrib/cutlass/target.cc b/src/relay/backend/contrib/cutlass/target.cc index ea040f6ff56a..50c8b84a9069 100644 --- a/src/relay/backend/contrib/cutlass/target.cc +++ b/src/relay/backend/contrib/cutlass/target.cc @@ -39,32 +39,32 @@ namespace cutlass { * src/relay/backend/contrib/cutlass/codegen.cc */ TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) .set_attr("RelayToTIR", CompileForCutlass()) // An integer specifying the compute capability. For example, 75 for Turing and // 80 or 86 for Ampere. - .add_attr_option("sm", runtime::Int(80)) + .add_attr_option("sm", Integer(80)) // Whether to use slower but very accurate (compared to tf32) 3xtf32 mode for // fp32 inputs on tensorcore. - .add_attr_option("use_3xtf32", runtime::Bool(true)) + .add_attr_option("use_3xtf32", Bool(true)) // Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in // parallel across split-K blocks, and a separate global reduction kernel is launched to // accumulate partial reductions. The profiler will pick the best split-k factor from the // given candidate list. Note that the larger split-K factor requires a larger workspace. // Currently, parallel split-k has been tested only for wgrad. For GEMM and other conv2d // kinds, split_k_slices is ignored. - .add_attr_option>("split_k_slices", Array{runtime::Int(1)}) + .add_attr_option>("split_k_slices", Array({1})) // When True, profile all kernel variants with smaller alignments than the largest possible. - .add_attr_option("profile_all_alignments", runtime::Bool(false)) + .add_attr_option("profile_all_alignments", Bool(false)) // Whether to profile all candidate kernels, or stop profiling after the first applicable kernel // is found. - .add_attr_option("find_first_valid", runtime::Bool(false)) + .add_attr_option("find_first_valid", Bool(false)) // Whether to compile profiler executables for different kernels in parallel. - .add_attr_option("use_multiprocessing", runtime::Bool(false)) + .add_attr_option("use_multiprocessing", Bool(false)) // Number of threads to use during compilation, or -1 to use number of cpus. - .add_attr_option("threads", runtime::Int(-1)) + .add_attr_option("threads", Integer(-1)) // Whether to replace sigmoid with tanh. - .add_attr_option("use_fast_math", runtime::Bool(false)) + .add_attr_option("use_fast_math", Bool(false)) // A temporary directory where intermediate compiled artifacts will be stored. .add_attr_option("tmp_dir", String("./tmp")); diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.cc b/src/relay/backend/contrib/ethosn/ethosn_api.cc index 0f539d96e919..a3f3e6e1eb6e 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api.cc +++ b/src/relay/backend/contrib/ethosn/ethosn_api.cc @@ -687,14 +687,14 @@ EthosnError EthosnAPI::Split(const Expr& expr, SplitParams* params) { sl::TensorInfo(input_tensor_shape, input_data_type, params->input_info.m_DataFormat, params->input_info.m_QuantizationInfo); params->split_info.m_Axis = attrs->axis; - if (const auto* sections_ptr = attrs->indices_or_sections.as()) { - auto sections = sections_ptr->value; + if (attrs->indices_or_sections->IsInstance()) { + auto sections = Downcast(attrs->indices_or_sections)->value; int size = input_tensor_shape[attrs->axis] / sections; for (int i = 0; i < sections; i++) { params->split_info.m_Sizes.push_back(size); } } else { - auto indices = Downcast>(attrs->indices_or_sections); + auto indices = Downcast>(attrs->indices_or_sections); int last_index = 0; for (const auto& i : indices) { params->split_info.m_Sizes.push_back(i->value - last_index); diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index 300372838416..54d0595c4634 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -307,7 +307,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { Array compile_artifacts; for (const auto& kv : mod->functions) { const tir::PrimFunc& prim_func = Downcast(kv.second); - auto params = prim_func->GetAttr>("ethos-u.constants"); + Optional> params = + prim_func->GetAttr>("ethos-u.constants"); ICHECK(params) << "microNPU params should be present"; auto primfunc_to_artifact_pf = tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact"); diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index d87447f863e2..23a873b2d392 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -97,7 +97,7 @@ class ExternalFuncIOHandler : public ExprRewriter { Expr CreateSplitReshapedTensors(const Expr& input, const Array& original_args) { Array> shapes; Array flatten_tensor_sizes; - Array split_indices; + Array split_indices; Array rets; int total_size = 0; @@ -132,7 +132,7 @@ class ExternalFuncIOHandler : public ExprRewriter { if (func->params.size() > 1) { Array> shapes; Array flatten_tensor_sizes; - Array split_indices; + Array split_indices; auto func_name = gv->name_hint; int total_size = 0; diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index de9c81a2706e..b45987f6be33 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -38,6 +38,6 @@ TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) .set_attr(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR()) .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime) - .add_attr_option("example_attribute", Integer(0)); + .add_attr_option("example_attribute", Integer(0)); } // namespace tvm diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index 1dd5e3a4d772..f4babad50a3e 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -177,12 +177,12 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { std::vector indices_or_sections; std::vector mode; std::vector axis = {std::to_string(split_attr->axis)}; - if (const auto* sections = split_attr->indices_or_sections.as()) { + if (const auto* sections = split_attr->indices_or_sections.as()) { mode.emplace_back("sections"); indices_or_sections.emplace_back(std::to_string(sections->value)); } else { mode.emplace_back("indices"); - auto indices = Downcast>(split_attr->indices_or_sections); + auto indices = Downcast>(split_attr->indices_or_sections); for (const auto& i : indices) { indices_or_sections.emplace_back(std::to_string(i->value)); } diff --git a/src/relay/backend/contrib/tensorrt/target.cc b/src/relay/backend/contrib/tensorrt/target.cc index a62dc25e329c..0277787a8c12 100644 --- a/src/relay/backend/contrib/tensorrt/target.cc +++ b/src/relay/backend/contrib/tensorrt/target.cc @@ -38,30 +38,30 @@ namespace tensorrt { * - Runtime: src/runtime/contrib/tensorrt/... */ TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) .set_attr("RelayToTIR", CompileForTensorRT()) // A array of three integers given the major, minor, and patch numbers for the supported // TensorRT compiler version. If empty will be auto-detected from linked library. Default empty. - .add_attr_option>("tensorrt_version", Array()) + .add_attr_option>("tensorrt_version", Array()) // If true, the first tensor dimension for most operators is allowed to be Any and // TensorRT will assume it represents a batch dimension only known at inference time. // Fewer Relay operators are supported in implicit batch mode. Default true. - .add_attr_option("use_implicit_batch", runtime::Bool(true)) + .add_attr_option("use_implicit_batch", Bool(true)) // If true, excludes sub-graphs which do not have multiply-accumulate operations, even though // TensorRT supports them. ad. This is a simple heuristic to optimize the partitioning between // TensorRT and TVM. Not required if using Collage for partitioning. Defalut false. - .add_attr_option("remove_no_mac_subgraphs", runtime::Bool(false)) + .add_attr_option("remove_no_mac_subgraphs", Bool(false)) // How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation. // Default 1G. - .add_attr_option("max_workspace_size", runtime::Int(1 << 30)) + .add_attr_option("max_workspace_size", Integer(1 << 30)) // If true, allows TensorRT to automatically convert float32 operations to float16. Must also be // enabled if any float16 operations are in the model. Note that TensorRT may still choose a // higher-precision kernel if it results in overall lower runtime, or if no low-precision // implementation exists. Default false. - .add_attr_option("use_fp16", runtime::Bool(false)) + .add_attr_option("use_fp16", Bool(false)) // If true, allows TensorRT to automatically convert float32 operations to uint8 // (aka quantized). Default false. - .add_attr_option("use_uint8", runtime::Bool(false)); + .add_attr_option("use_uint8", Bool(false)); } // namespace tensorrt } // namespace contrib diff --git a/src/relay/backend/contrib/uma/targets.cc b/src/relay/backend/contrib/uma/targets.cc index 0499c0bba198..244f243749c1 100644 --- a/src/relay/backend/contrib/uma/targets.cc +++ b/src/relay/backend/contrib/uma/targets.cc @@ -58,7 +58,7 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") .add_attr_option("model") .add_attr_option>("libs") .add_attr_option("host") - .add_attr_option("from_device") + .add_attr_option("from_device") .set_attr( attr::kRelayToTIR, relay::contrib::uma::RelayToTIR(target_name)) .set_attr("TIRToRuntime", relay::contrib::uma::TIRToRuntime); @@ -75,9 +75,8 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") } if (default_value->IsInstance()) { target_kind.add_attr_option(option_name, Downcast(default_value)); - } else if (default_value->IsInstance()) { - target_kind.add_attr_option(option_name, - Downcast(default_value)); + } else if (default_value->IsInstance()) { + target_kind.add_attr_option(option_name, Downcast(default_value)); } else { LOG(FATAL) << "TypeError: Only String, Integer, or Bool are supported. " << "Given attribute option type: " << attr_option.second->GetTypeKey(); diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc index 66feac4699e6..1d6caecb87ba 100644 --- a/src/relay/backend/executor.cc +++ b/src/relay/backend/executor.cc @@ -89,13 +89,13 @@ ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { /********** Register Executors and options **********/ TVM_REGISTER_EXECUTOR("aot") - .add_attr_option("link-params", runtime::Bool(true)) - .add_attr_option("unpacked-api") + .add_attr_option("link-params", Bool(true)) + .add_attr_option("unpacked-api") .add_attr_option("interface-api") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constant-byte-alignment"); + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constant-byte-alignment"); -TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", runtime::Bool(false)); +TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", Bool(false)); /********** Registry **********/ diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc index 0534298ea44d..923c9b2d5f65 100644 --- a/src/relay/backend/runtime.cc +++ b/src/relay/backend/runtime.cc @@ -88,9 +88,9 @@ RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { /********** Register Runtimes and options **********/ -TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); -TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); /********** Registry **********/ diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 3e86e1c8eaf9..0c0ff7290115 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -73,42 +73,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& exp } bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { - // Unwrapping arrays may find user-provided FFI types in the - // attributes (e.g. Defining pad_value as ((0,0), (0,0)) will result - // in runtime::Int. These need to be converted to compile-time IR - // types when encountered. - if (lhs->IsInstance() || - lhs->IsInstance() || - lhs->IsInstance()) { - TVMRetValue lhs_convert; - lhs_convert = lhs; - PrimExpr lhs_expr = lhs_convert; - return MatchRetValue(lhs_expr, rhs); - } - - // StructuralEqual doesn't check for conversions between FFI types - // and IR types, but the pattern-matcher should. Therefore, - // explicitly recurse into the array. - if (auto opt_lhs_array = lhs.as>()) { - if (Optional> opt_rhs_array = rhs) { - Array lhs_array = opt_lhs_array.value(); - Array rhs_array = opt_rhs_array.value(); - if (lhs_array.size() != rhs_array.size()) { - return false; - } - for (size_t i = 0; i < lhs_array.size(); i++) { - TVMRetValue rhs_item; - rhs_item = rhs_array[i]; - if (!MatchRetValue(lhs_array[i], rhs_item)) { - return false; - } - } - return true; - } else { - return false; - } - } - switch (rhs.type_code()) { case kDLInt: if (auto* val = lhs.as()) { diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 222aba4bd25b..50d8531c7dd0 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -79,7 +79,7 @@ Expr MakeReshape(Expr data, Array newshape, bool allowzero = false); Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, Integer rhs_end); -Expr MakeSplit(Expr data, Variant> indices_or_sections, int axis); +Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); Expr MakeSqueeze(Expr data, Array axis); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 96f833d80505..fde6daa4d851 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2984,10 +2984,10 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, Layout ret = Layout::Undef(); size_t size = 0; - if (const auto* sections = param->indices_or_sections.as()) { + if (const IntImmNode* sections = param->indices_or_sections.as()) { size = sections->value; } else { - size = Downcast>(param->indices_or_sections).size() + 1; + size = Downcast>(param->indices_or_sections).size() + 1; } // If new_in_layouts are defined, this code tries to modify the layout. @@ -2998,12 +2998,13 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, param->axis = new_index; int factor = new_in_layouts[0].FactorOf(sp_dim); if (factor > 1) { - if (!param->indices_or_sections.as()) { - auto ios = Downcast>(param->indices_or_sections); - Array new_ios; + if (!param->indices_or_sections.as()) { + auto ios = Downcast>(param->indices_or_sections); + Array new_ios; for (const auto& v : ios) { - new_ios.push_back(runtime::Int(v->value / factor)); - if (v->value % factor) { + const IntImmNode* vint = v.as(); + new_ios.push_back(vint->value / factor); + if (vint->value % factor) { divisible = false; } } @@ -3040,7 +3041,7 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range."; ICHECK_GE(axis, 0) << "axis should be within the input dimension range."; - if (const auto* sections = param->indices_or_sections.as()) { + if (const IntImmNode* sections = param->indices_or_sections.as()) { if (!data->shape[axis].as()) { ICHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == tir::make_zero(DataType::Int(64)))) @@ -3060,8 +3061,8 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TupleType(Array(fields))); } else { Array indices; - for (auto index : Downcast>(param->indices_or_sections)) { - indices.push_back(IntImm(DataType::Int(32), index->value)); + for (auto i : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), i.as()->value)); } auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; @@ -3096,20 +3097,19 @@ Array SplitCompute(const Attrs& attrs, const Array& inpu const auto param = attrs.as(); ICHECK(param != nullptr); - if (const auto* sections = param->indices_or_sections.as()) { + if (const IntImmNode* sections = param->indices_or_sections.as()) { int64_t num_sections = sections->value; return Array{topi::split_sections(inputs[0], num_sections, param->axis)}; } else { Array indices; - for (auto index : Downcast>(param->indices_or_sections)) { - indices.push_back(IntImm(DataType::Int(32), index->value)); + for (auto i : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), i.as()->value)); } return Array{topi::split(inputs[0], indices, param->axis)}; } } -Expr MakeSplit(Expr data, Variant> indices_or_sections, - int axis) { +Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); @@ -3117,7 +3117,17 @@ Expr MakeSplit(Expr data, Variant> indices_or_ return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.split").set_body_typed(MakeSplit); +TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) { + if (args.type_codes[1] == kDLInt) { + // Note: we change it from Int(64) to Int(32) for now as + // combine_parallel_dense will transform the graph with Int(32). + // More invetigation is needs to check which one we should use. + *rv = + MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast(args[1])), args[2]); + } else { + *rv = MakeSplit(args[0], args[1], args[2]); + } +}); RELAY_REGISTER_OP("split") .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. @@ -4147,13 +4157,11 @@ bool ScanopRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Optional exclusive) { +Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) { auto attrs = make_object(); attrs->dtype = dtype; attrs->axis = axis; - if (exclusive.defined()) { - attrs->exclusive = exclusive.value(); - } + attrs->exclusive = exclusive; static const Op& op = Op::Get("cumsum"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index 74827f166b51..a41e1e0d6674 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -159,7 +159,7 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int index = 0; - auto split = MakeSplit(data, runtime::Int(branches.size()), 0); + auto split = MakeSplit(data, Integer(branches.size()), 0); for (const auto& branch : branches) { auto split_data = TupleGetItem(split, index++); auto squeezed_data = MakeSqueeze(split_data, {0}); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index df28506c6217..34f986b251a2 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -266,7 +266,7 @@ class ConstantFolder : public MixedModeMutator { // always use graph executor with no link-params dict.Set(tvm::attr::kExecutor, - relay::Executor::Create("graph", {{"link-params", runtime::Bool(false)}})); + relay::Executor::Create("graph", {{"link-params", Bool(false)}})); Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), eval_cpu_dev_, eval_cpu_target_, dict)); VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index da7a8f6420cd..edf1e4c99f4d 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -36,6 +36,8 @@ namespace tvm { namespace relay { +using namespace tvm::runtime; + /*! What is automatic differentiation(AD) and why is it important? * By AD, we roughly mean, given a term which denotes some mathematical function, * derive a term which denotes the derivative of that mathematical function. diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 1112755b76a0..5026b1bcba79 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -66,7 +66,7 @@ using CachedCastNodes = std::unordered_map, // Return array is of type : [MixedTypeConversionCategory (int), String, String] // The fields are : [ConversionCategory, accumulation_datatype, output_datatype] // Call is a call node, DataType is the mixed precision type -using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc>( +using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc( const Call& call_node, const std::string& target_dtype_str)>; /*! \brief This class transforms the given relay module into a version where @@ -372,7 +372,7 @@ class MixedPrecisionPass : public MixedModeMutator { if (attr_map.count(op)) { // Calculate the conversion category and dtypes from registered attribute. FTVMMixedPrecisionConversionType func = attr_map[op]; - Array> op_descriptor = + Array op_descriptor = func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type_)); ICHECK(op_descriptor.size() == 3) << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size() diff --git a/src/runtime/boxed_primitive.cc b/src/runtime/boxed_primitive.cc deleted file mode 100644 index 9ab83a7b471c..000000000000 --- a/src/runtime/boxed_primitive.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/runtime/boxed_primitive.cc - * \brief Implementations of ObjectRef wrapper. - */ - -#include -#include - -namespace tvm { -namespace runtime { - -TVM_REGISTER_OBJECT_TYPE(BoxNode); -TVM_REGISTER_OBJECT_TYPE(BoxNode); -TVM_REGISTER_OBJECT_TYPE(BoxNode); - -/* \brief Allow explicit construction of Box - * - * Convert a `bool` to `Box`. For use in FFI handling, to - * provide an umambiguous representation between `bool(true)` and - * `int(1)`. Will be automatically unboxed in the case where a - * `Box` is provided to a PackedFunc that requires `int` input, - * mimicking C++'s default conversions. - * - * This is only needed for Box, as Box and Box - * can be converted in C++ as part of `TVMArgValue::operator - * ObjectRef()` without ambiguity, postponing conversions until - * required. - */ -TVM_REGISTER_GLOBAL("runtime.BoxBool").set_body_typed([](bool value) { return Box(value); }); - -/* \brief Return the underlying boolean object. - * - * Used while unboxing a boolean return value during FFI handling. - * The return type is intentionally `int` and not `bool`, to avoid - * recursive unwrapping of boolean values. - * - * This is only needed for Box, as Box and Box - * can be unambiguously unboxed as part of - * `TVMRetValue::operator=(ObjectRef)`. - */ -TVM_REGISTER_GLOBAL("runtime.UnBoxBool").set_body_typed([](Box obj) -> int { - return obj->value; -}); - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 04d36ad8bcab..57979b160ea7 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -361,18 +361,14 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r TVMAPISetLastError("ModuleGetFunction expects second argument to be a string"); return kTvmErrorFunctionCallWrongArgType; } - - if (type_codes[2] == kDLInt) { - query_imports = args[2].v_int64 != 0; - } else if (type_codes[2] == kTVMArgBool) { - query_imports = args[2].v_bool; - } else { + if (type_codes[2] != kDLInt) { TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer"); return kTvmErrorFunctionCallWrongArgType; } mod = (TVMModuleHandle)args[0].v_handle; name = args[1].v_str; + query_imports = args[2].v_int64 != 0; to_return = TVMModGetFunction(mod, name, query_imports, &ret_value->v_handle); if (to_return == 0) { diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index f7204e372f6d..493bc3fb1dc9 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -102,10 +102,10 @@ DRef BcastSessionObj::CallWithPacked(const TVMArgs& args) { int cnt = 0; for (int i = 3; i < num_args; ++i) { int type_code = type_codes[i]; - if (type_code != kDLInt && type_code != kDLUInt && type_code != kTVMArgBool && - type_code != kDLFloat && type_code != kTVMDataType && type_code != kDLDevice && - type_code != kTVMOpaqueHandle && type_code != kTVMStr && type_code != kTVMNullptr && - type_code != kTVMBytes && type_code != kTVMObjectHandle) { + if (type_code != kDLInt && type_code != kDLUInt && type_code != kDLFloat && + type_code != kTVMDataType && type_code != kDLDevice && type_code != kTVMOpaqueHandle && + type_code != kTVMStr && type_code != kTVMNullptr && type_code != kTVMBytes && + type_code != kTVMObjectHandle) { os << "\n Argument #" << i - 3 << " has unsupported type code: " << type_code << " (" << ArgTypeCode2Str(type_code) << ")"; cnt += 1; diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index 485ebdb449da..d08dadb02bb9 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -325,10 +325,6 @@ struct RPCReference { channel->template Write(value.v_int64); break; } - case kTVMArgBool: { - channel->template Write(value.v_bool); - break; - } case kTVMDataType: { channel->Write(value.v_type); // padding @@ -436,10 +432,6 @@ struct RPCReference { channel->template Read(&(value.v_int64)); break; } - case kTVMArgBool: { - channel->template Read(&(value.v_bool)); - break; - } case kTVMDataType: { channel->Read(&(value.v_type)); int32_t padding = 0; diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 9fe6fba80f5c..3908ad1112a0 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -279,11 +279,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo * \param err_ctx Additional context if error occurs. */ void CheckPrimValueInfo(TVMArgValue arg, DataType dtype, Optional err_ctx) { - if (arg.IsObjectRef()) { - ObjectRef obj = arg.AsObjectRef(); - LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", expected dtype " << dtype - << ", but received ObjectRef of type " << obj->GetTypeKey(); - } else if (dtype.is_bool()) { + if (dtype.is_bool()) { arg.operator bool(); } else if (dtype.is_int()) { arg.operator int64_t(); @@ -430,9 +426,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.to_device") * \return Bool */ bool ReadIfCond(TVMArgValue cond) { - if (cond.type_code() == kDLInt || cond.type_code() == kTVMArgBool) { - return cond.operator bool(); - } + if (cond.type_code() == kDLInt) return cond.operator bool(); NDArray arr = cond.operator tvm::runtime::NDArray(); if (arr->device.device_type != kDLCPU) { arr = arr.CopyTo(DLDevice{kDLCPU, 0}); diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 61bdec680a29..54194e7e2a41 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -323,33 +323,12 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } else if (const auto* float_imm = value.as()) { // TODO(yelite): Make float number printing roundtrippable + output_.precision(17); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { output_ << '"' << float_imm->value << '"'; - } else if (std::nearbyint(float_imm->value) == float_imm->value) { - // Special case for floating-point values which would be - // formatted using %g, are not displayed in scientific - // notation, and whose fractional part is zero. - // - // By default, using `operator<<(std::ostream&, double)` - // delegates to the %g printf formatter. This strips off any - // trailing zeros, and also strips the decimal point if no - // trailing zeros are found. When parsed in python, due to the - // missing decimal point, this would incorrectly convert a float - // to an integer. Providing the `std::showpoint` modifier - // instead delegates to the %#g printf formatter. On its own, - // this resolves the round-trip errors, but also prevents the - // trailing zeros from being stripped off. - std::showpoint(output_); - std::fixed(output_); - output_.precision(1); - output_ << float_imm->value; } else { - std::defaultfloat(output_); - std::noshowpoint(output_); - output_.precision(17); output_ << float_imm->value; } - } else if (const auto* string_obj = value.as()) { output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index 686f486da6eb..ef68b89b5bf4 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -30,21 +30,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return LiteralDoc::Str(s, p); }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](runtime::Bool obj, ObjectPath p, IRDocsifier d) -> Doc { - return LiteralDoc::Boolean(obj->value, p); - }); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](runtime::Int obj, ObjectPath p, IRDocsifier d) -> Doc { - return LiteralDoc::Int(obj->value, p); - }); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](runtime::Float obj, ObjectPath p, IRDocsifier d) -> Doc { - return LiteralDoc::Float(obj->value, p); - }); - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // "", [](Array array, ObjectPath p, IRDocsifier d) -> Doc { diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 35a9f35db491..6f9a8cbf8918 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -75,11 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases - if (n->dtype.is_bool()) { - return LiteralDoc::Boolean(n->value, n_p); - } else { - return LiteralDoc::Int(n->value, n_p); - } + return LiteralDoc::Int(n->value, n_p); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/support/array.h b/src/support/array.h index 0d4c8134787b..0ca57a2410c5 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -164,14 +164,12 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { - TVMRetValue ret_value; - ret_value = array; - Array as_int_vec = ret_value; - + inline std::vector operator()(const Array& vec) const { std::vector results; - for (const auto& value : as_int_vec) { - results.push_back(value->value); + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); } return results; } @@ -179,14 +177,12 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { - TVMRetValue ret_value; - ret_value = array; - Array as_int_vec = ret_value; - + inline std::vector operator()(const Array& vec) const { std::vector results; - for (const auto& value : as_int_vec) { - results.push_back(value->value); + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); } return results; } @@ -195,13 +191,11 @@ struct AsVectorImpl { template struct AsVectorImpl { inline std::vector operator()(const Array& array) const { - TVMRetValue ret_value; - ret_value = array; - Array as_int_vec = ret_value; - std::vector results; - for (const auto& value : as_int_vec) { - results.push_back(value->value); + for (const TSrcObjectRef& x : array) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); } return results; } @@ -227,10 +221,8 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (auto x : vec) { - TVMRetValue ret_value; - ret_value = x; - result.push_back(ret_value); + for (int x : vec) { + result.push_back(Integer(x)); } return result; } @@ -241,10 +233,8 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (auto x : vec) { - TVMRetValue ret_value; - ret_value = x; - result.push_back(ret_value); + for (int64_t x : vec) { + result.push_back(Integer(x)); } return result; } @@ -255,10 +245,8 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (auto x : vec) { - TVMRetValue ret_value; - ret_value = x; - result.push_back(ret_value); + for (double x : vec) { + result.push_back(FloatImm(tvm::DataType::Float(64), x)); } return result; } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 928cdfcab80b..aec57a1eb20d 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -189,58 +189,6 @@ TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian TVM_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); -TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); - -TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); - -TVM_REGISTER_GLOBAL("testing.AcceptsObjectRef").set_body_typed([](ObjectRef arg) -> ObjectRef { - return arg; -}); - -TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray") - .set_body_typed([](Array arg) -> ObjectRef { return arg[0]; }); - -TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") - .set_body_typed([](Map map, ObjectRef key) -> ObjectRef { - return map[key]; - }); - -TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") - .set_body_typed([](Map map) -> ObjectRef { return map; }); - -TVM_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { - return expr; -}); - -TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") - .set_body_typed([](Array arr) -> ObjectRef { - for (ObjectRef item : arr) { - CHECK(item->IsInstance()) - << "Array contained " << item->GetTypeKey() << " when it should contain PrimExpr"; - } - return arr; - }); - -TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") - .set_body_typed([](Array> arr) -> ObjectRef { - for (ObjectRef item : arr) { - CHECK(item->IsInstance() || item->IsInstance()) - << "Array contained " << item->GetTypeKey() - << " when it should contain either PrimExpr or PackedFunc"; - } - return arr; - }); - -TVM_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") - .set_body_typed([](Map map) -> ObjectRef { - for (const auto& kv : map) { - ObjectRef value = kv.second; - CHECK(value->IsInstance()) - << "Map contained " << value->GetTypeKey() << " when it should contain PrimExpr"; - } - return map; - }); - /** * Simple event logger that can be used for testing purposes */ diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 21899a12c4b0..481ba39cc7b1 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -347,26 +347,18 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value } case builtin::kTVMValueContent: { ICHECK_EQ(t.lanes(), 1); - if (t.is_bool()) { - // The stride between adjacent entries is still - // `sizeof(TVMValue)==64`, even if the enum currently holds a - // boolean. - buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(t_int64_, buf, index); - buf = builder_->CreatePointerCast(buf, DTypeToLLVMType(t)->getPointerTo()); - return TypedPointer(t_int8_, buf); - } else if (t.is_int() && t.bits() == 64) { + ICHECK(t.is_handle() || t.bits() == 64); + if (t.is_int()) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); - } else if (t.is_float() && t.bits() == 64) { + } else if (t.is_float()) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); - } else if (t.is_handle()) { + } else { + ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); - } else { - LOG(DEBUG) << "DataType " << t << " cannot be stored into a TVMValue"; } } default: @@ -1374,16 +1366,9 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { return builder_->CreatePointerCast(ref.addr, t_void_p_); + } else { + return builder_->CreateLoad(ref.type, ref.addr); } - - llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr); - - if (op->dtype == DataType::Bool()) { - struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value); - } - - return struct_value; - } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 0406dcf951bb..dd5a3fb681ee 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -294,10 +294,10 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) target_options_.MCOptions.ABIName = Downcast(target.Get("mabi")); } - auto maybe_level = target.Get("opt-level").as(); + auto maybe_level = Downcast(target.Get("opt-level")); #if TVM_LLVM_VERSION <= 170 if (maybe_level.defined()) { - int level = maybe_level.value()->value; + int level = maybe_level->value; if (level <= 0) { opt_level_ = llvm::CodeGenOpt::None; } else if (level == 1) { @@ -313,7 +313,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } #else if (maybe_level.defined()) { - int level = maybe_level.value()->value; + int level = maybe_level->value; if (level <= 0) { opt_level_ = llvm::CodeGenOptLevel::None; } else if (level == 1) { @@ -333,12 +333,8 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) // Fast math options - auto GetBoolFlag = [&target](llvm::StringRef name) -> bool { - if (auto flag = target.Get(name.str())) { - return Downcast(flag); - } else { - return false; - } + auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool { + return Downcast(target.Get(flag.str()).value_or(Bool(false))); }; if (GetBoolFlag("fast-math")) { #if TVM_LLVM_VERSION >= 60 diff --git a/src/target/tag.cc b/src/target/tag.cc index d45bf61a38f1..9eca3072df0e 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -76,61 +76,61 @@ TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", runtime::Int(4)}, + {"num-cores", Integer(4)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", runtime::Int(4)}}}}); + {"num-cores", Integer(4)}}}}); #if TVM_LLVM_VERSION >= 110 TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") .set_config({{"kind", String("cuda")}, {"arch", String("sm_72")}, - {"max_shared_memory_per_block", runtime::Int(49152)}, - {"max_threads_per_block", runtime::Int(1024)}, - {"thread_warp_size", runtime::Int(32)}, - {"registers_per_block", runtime::Int(65536)}, + {"max_shared_memory_per_block", Integer(49152)}, + {"max_threads_per_block", Integer(1024)}, + {"thread_warp_size", Integer(32)}, + {"registers_per_block", Integer(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", runtime::Int(8)}}}}); + {"num-cores", Integer(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", runtime::Int(49152)}, - {"max_threads_per_block", runtime::Int(1024)}, - {"thread_warp_size", runtime::Int(32)}, - {"registers_per_block", runtime::Int(65536)}, + {"max_shared_memory_per_block", Integer(49152)}, + {"max_threads_per_block", Integer(1024)}, + {"thread_warp_size", Integer(32)}, + {"registers_per_block", Integer(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", runtime::Int(6)}}}}); + {"num-cores", Integer(6)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", runtime::Int(49152)}, - {"max_threads_per_block", runtime::Int(1024)}, - {"thread_warp_size", runtime::Int(32)}, - {"registers_per_block", runtime::Int(65536)}, + {"max_shared_memory_per_block", Integer(49152)}, + {"max_threads_per_block", Integer(1024)}, + {"thread_warp_size", Integer(32)}, + {"registers_per_block", Integer(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a78")}, - {"num-cores", runtime::Int(8)}}}}); + {"num-cores", Integer(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", runtime::Int(49152)}, - {"max_threads_per_block", runtime::Int(1024)}, - {"thread_warp_size", runtime::Int(32)}, - {"registers_per_block", runtime::Int(65536)}, + {"max_shared_memory_per_block", Integer(49152)}, + {"max_threads_per_block", Integer(1024)}, + {"thread_warp_size", Integer(32)}, + {"registers_per_block", Integer(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a78")}, - {"num-cores", runtime::Int(12)}}}}); + {"num-cores", Integer(12)}}}}); #endif // TVM_LLVM_VERSION >= 110 #endif // TVM_LLVM_HAS_AARCH64_TARGET @@ -139,10 +139,10 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") {"kind", String("cuda")}, \ {"keys", Array{"cuda", "gpu"}}, \ {"arch", String(Arch)}, \ - {"max_shared_memory_per_block", runtime::Int(SharedMem)}, \ - {"max_threads_per_block", runtime::Int(1024)}, \ - {"thread_warp_size", runtime::Int(32)}, \ - {"registers_per_block", runtime::Int(RegPerBlock)}, \ + {"max_shared_memory_per_block", Integer(SharedMem)}, \ + {"max_threads_per_block", Integer(1024)}, \ + {"thread_warp_size", Integer(32)}, \ + {"registers_per_block", Integer(RegPerBlock)}, \ }) // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus @@ -158,9 +158,9 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) - .with_config("l2_cache_size_bytes", runtime::Int(41943040)); + .with_config("l2_cache_size_bytes", Integer(41943040)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536) - .with_config("l2_cache_size_bytes", runtime::Int(52428800)); + .with_config("l2_cache_size_bytes", Integer(52428800)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); @@ -263,7 +263,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536) - .with_config("l2_cache_size_bytes", runtime::Int(75497472)); + .with_config("l2_cache_size_bytes", Integer(75497472)); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536); @@ -416,7 +416,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ {"keys", Array{"x86", "cpu"}}, \ {"mcpu", String(Arch)}, \ - {"num-cores", runtime::Int(Cores)}}); + {"num-cores", Integer(Cores)}}); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.xlarge", 2, "skylake-avx512"); @@ -432,9 +432,9 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ TVM_REGISTER_TARGET_TAG(Name).set_config( \ {{"kind", String("metal")}, \ - {"max_threads_per_block", runtime::Int(ThreadsPerBlock)}, \ - {"max_shared_memory_per_block", runtime::Int(SharedMem)}, \ - {"thread_warp_size", runtime::Int(WarpSize)}, \ + {"max_threads_per_block", Integer(ThreadsPerBlock)}, \ + {"max_shared_memory_per_block", Integer(SharedMem)}, \ + {"thread_warp_size", Integer(WarpSize)}, \ {"host", Map{{"kind", String("llvm")}, \ {"mtriple", String("arm64-apple-macos")}, \ {"mcpu", String("apple-latest")}}}}); diff --git a/src/target/target.cc b/src/target/target.cc index a8337b58ae9b..cd2e3714e422 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -359,31 +359,24 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { std::string interp_str = Interpret(str); - if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex() || - info.type_index == runtime::Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing integer or boolean + if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + // Parsing integer std::istringstream is(interp_str); int v; if (!(is >> v)) { std::string lower(interp_str.size(), '\x0'); std::transform(interp_str.begin(), interp_str.end(), lower.begin(), [](unsigned char c) { return std::tolower(c); }); - // Mimic C++ automatic conversions, allowing bool to be used for - // integer parameters. + // Bool is a subclass of IntImm, so allow textual boolean values. if (lower == "true") { v = 1; } else if (lower == "false") { v = 0; } else { - throw Error(": Cannot parse integer from string: " + interp_str); + throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); } } - - if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - return runtime::Int(v); - } else { - return runtime::Bool(v); - } + return Integer(v); } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing string, strip leading/trailing spaces, and enclosing quotes if any auto start = interp_str.find_first_not_of(' '); @@ -417,13 +410,13 @@ ObjectRef TargetInternal::ParseType(const std::string& str, ObjectRef TargetInternal::ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) { - if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer - return GetRef(ObjTypeCheck(obj, "runtime.BoxInt")); - } else if (info.type_index == String::ContainerType::RuntimeTypeIndex()) { + return GetRef(ObjTypeCheck(obj, "Integer")); + } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing string return GetRef(ObjTypeCheck(obj, "String")); - } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { + } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing target if (auto opt = obj.as()) { return opt.value(); @@ -490,11 +483,7 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, /********** Stringifying **********/ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { - if (const auto* p = obj.as()) { - return std::to_string(p->value); - } else if (const auto* p = obj.as()) { - return std::to_string(p->value); - } else if (const auto* p = obj.as()) { + if (const auto* p = obj.as()) { return std::to_string(p->value); } if (auto tvm_str = obj.as()) { @@ -505,7 +494,7 @@ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { } return u; } - LOG(FATAL) << "Cannot stringify object of type " << obj->GetTypeKey(); + LOG(FATAL) << "Cannot stringify this object"; } std::string TargetInternal::StringifyArray(const ArrayNode& array) { @@ -964,7 +953,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // If requested, query attributes from the device. User-specified // parameters take precedence over queried parameters. if (attrs.count("from_device")) { - int device_id = Downcast(attrs.at("from_device"))->value; + int device_id = Downcast(attrs.at("from_device")).IntValue(); attrs.erase("from_device"); auto device_params = QueryDevice(device_id, target.get()); @@ -1017,13 +1006,38 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, for (const auto& kv : target->kind->key2vtype_) { const String& key = kv.first; + const TargetKindNode::ValueTypeInfo& type_info = kv.second; TVMRetValue ret; api->GetTargetProperty(device, key, &ret); - // Delegate conversion from TVMRetValue to the FFI's default conversions. - if (Optional opt = ret) { - output[key] = opt.value(); + switch (ret.type_code()) { + case kTVMNullptr: + // Nothing returned for this parameter, move on to the next one. + continue; + + case kTVMArgInt: + if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + output[key] = Integer(static_cast(ret)); + } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + output[key] = Bool(static_cast(ret)); + } else { + LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received integer from device api"; + } + break; + + case kTVMStr: + ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) + << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received string from device api"; + output[key] = String(ret.operator std::string()); + break; + + default: + LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; + break; } } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index fced74c3a559..708d3ccd7621 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -243,7 +243,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { * \return The updated attributes */ TargetJSON TestTargetParser(TargetJSON target) { - Map features = {{"is_test", runtime::Bool(true)}}; + Map features = {{"is_test", Bool(true)}}; target.Set("features", features); return target; } @@ -256,16 +256,16 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mtriple") .add_attr_option("mfloat-abi") .add_attr_option("mabi") - .add_attr_option("num-cores") + .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags - .add_attr_option("fast-math") // implies all the below - .add_attr_option("fast-math-nnan") - .add_attr_option("fast-math-ninf") - .add_attr_option("fast-math-nsz") - .add_attr_option("fast-math-arcp") - .add_attr_option("fast-math-contract") - .add_attr_option("fast-math-reassoc") - .add_attr_option("opt-level") + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") // LLVM command line flags, see below .add_attr_option>("cl-opt") // LLVM JIT engine mcjit/orcjit @@ -273,7 +273,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .set_default_keys({"cpu"}) // Force the external codegen kind attribute to be registered, even if no external // codegen targets are enabled by the TVM build. - .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(false)) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(false)) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); // Note regarding the "cl-opt" attribute: @@ -301,29 +301,28 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("mcpu") .add_attr_option("march") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constants-byte-alignment") + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constants-byte-alignment") .set_default_keys({"cpu"}) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("arch") - .add_attr_option("max_shared_memory_per_block") - .add_attr_option("max_threads_per_block") - .add_attr_option("thread_warp_size", runtime::Int(32)) - .add_attr_option("registers_per_block") - .add_attr_option("l2_cache_size_bytes") - .add_attr_option("max_num_threads", - runtime::Int(1024)) // TODO(@zxybazh): deprecate it + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_threads_per_block") + .add_attr_option("thread_warp_size", Integer(32)) + .add_attr_option("registers_per_block") + .add_attr_option("l2_cache_size_bytes") + .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("mtriple") - .add_attr_option("max_num_threads", runtime::Int(1024)) - .add_attr_option("thread_warp_size", runtime::Int(32)) + .add_attr_option("max_num_threads", Integer(1024)) + .add_attr_option("thread_warp_size", Integer(32)) .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateNVPTXAttrs); @@ -333,24 +332,24 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option>("mattr") // TODO(masahi): Support querying from a target device // On RDNA cards, thread_warp_size should be 32 - .add_attr_option("max_num_threads", runtime::Int(256)) - .add_attr_option("max_threads_per_block", runtime::Int(256)) - .add_attr_option("max_shared_memory_per_block", runtime::Int(65536)) - .add_attr_option("thread_warp_size", runtime::Int(64)) + .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_threads_per_block", Integer(256)) + .add_attr_option("max_shared_memory_per_block", Integer(65536)) + .add_attr_option("thread_warp_size", Integer(64)) .set_default_keys({"rocm", "gpu"}) .set_target_parser(UpdateROCmAttrs); TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) - .add_attr_option("max_threads_per_block", runtime::Int(256)) - .add_attr_option("max_shared_memory_per_block", runtime::Int(16384)) - .add_attr_option("max_num_threads", runtime::Int(256)) - .add_attr_option("thread_warp_size", runtime::Int(1)) - .add_attr_option("texture_spatial_limit", runtime::Int(16384)) + .add_attr_option("max_threads_per_block", Integer(256)) + .add_attr_option("max_shared_memory_per_block", Integer(16384)) + .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("thread_warp_size", Integer(1)) + .add_attr_option("texture_spatial_limit", Integer(16384)) // Faced that Qualcomm OpenCL runtime crashed without any error message in // the case when the number of kernel arguments was pretty big. OpenCL doesn't // specify any limitations on the number of kernel arguments. max_function_args // equals to 128 looks like a reasonable number of kernel arguments. - .add_attr_option("max_function_args", runtime::Int(128)) + .add_attr_option("max_function_args", Integer(128)) .set_default_keys({"opencl", "gpu"}); // The metal has some limitations on the number of input parameters. This is why attribute @@ -359,55 +358,55 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) // https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc // See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf TVM_REGISTER_TARGET_KIND("metal", kDLMetal) - .add_attr_option("max_num_threads", runtime::Int(256)) - .add_attr_option("max_threads_per_block", runtime::Int(256)) - .add_attr_option("max_shared_memory_per_block", runtime::Int(32768)) - .add_attr_option("thread_warp_size", runtime::Int(16)) - .add_attr_option("max_function_args", runtime::Int(31)) + .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_threads_per_block", Integer(256)) + .add_attr_option("max_shared_memory_per_block", Integer(32768)) + .add_attr_option("thread_warp_size", Integer(16)) + .add_attr_option("max_function_args", Integer(31)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option>("mattr") // Feature support - .add_attr_option("supports_float16") - .add_attr_option("supports_float32", runtime::Bool(true)) - .add_attr_option("supports_float64") - .add_attr_option("supports_int8") - .add_attr_option("supports_int16") - .add_attr_option("supports_int32", runtime::Bool(true)) - .add_attr_option("supports_int64") - .add_attr_option("supports_8bit_buffer") - .add_attr_option("supports_16bit_buffer") - .add_attr_option("supports_storage_buffer_storage_class") - .add_attr_option("supports_push_descriptor") - .add_attr_option("supports_dedicated_allocation") - .add_attr_option("supports_integer_dot_product") - .add_attr_option("supports_cooperative_matrix") - .add_attr_option("supported_subgroup_operations") + .add_attr_option("supports_float16") + .add_attr_option("supports_float32", Bool(true)) + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32", Bool(true)) + .add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix") + .add_attr_option("supported_subgroup_operations") // Physical device limits - .add_attr_option("max_num_threads", runtime::Int(256)) - .add_attr_option("max_threads_per_block", runtime::Int(256)) - .add_attr_option("thread_warp_size", runtime::Int(1)) - .add_attr_option("max_block_size_x") - .add_attr_option("max_block_size_y") - .add_attr_option("max_block_size_z") - .add_attr_option("max_push_constants_size") - .add_attr_option("max_uniform_buffer_range") - .add_attr_option("max_storage_buffer_range") - .add_attr_option("max_per_stage_descriptor_storage_buffer") - .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_threads_per_block", Integer(256)) + .add_attr_option("thread_warp_size", Integer(1)) + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") // Other device properties .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_name") - .add_attr_option("driver_version") - .add_attr_option("vulkan_api_version") - .add_attr_option("max_spirv_version") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") // Tags .set_default_keys({"vulkan", "gpu"}); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) - .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_num_threads", Integer(256)) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL) // line break @@ -424,8 +423,8 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) .add_attr_option("mcpu") .add_attr_option("mtriple") .add_attr_option>("llvm-options") - .add_attr_option("num-cores") - .add_attr_option("vtcm-capacity") + .add_attr_option("num-cores") + .add_attr_option("vtcm-capacity") .set_default_keys({"hexagon", "cpu"}); TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index fb839c28da96..5797d2295bab 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -56,25 +56,10 @@ TVM_REGISTER_NODE_TYPE(ComputeOpNode); /// Verify if ComputeOp is valid with respect to Reduce operations. static void VerifyComputeOp(const ComputeOpNode* op); -static inline void AssertReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - const char* shared_text = - "When a TE compute node produces multiple outputs, " - "each of which is a reduction, " - "each reduction must be structurally identical, " - "except for the ReduceNode::value_index. "; - - StructuralEqual eq; - - ICHECK(a->combiner.same_as(b->combiner)) << shared_text << "However, the reduction operation " - << a->combiner << " does not match " << b->combiner; - ICHECK(a->source.same_as(b->source)) - << shared_text << "However, the input " << a->source << " does not match " << b->source; - ICHECK(eq(a->axis, b->axis)) << shared_text << "However, the reduction axis " << a->axis - << " does not match " << b->axis; - ICHECK(eq(a->condition, b->condition)) << shared_text << "However, the predicate " << a->condition - << " does not match " << b->condition; - ICHECK(eq(a->init, b->init)) << shared_text << "However, the initial value " << a->init - << " does not match " << b->init; +inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { + return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && StructuralEqual()(a->condition, b->condition) && + ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); } int ComputeOpNode::num_outputs() const { return body.size(); } @@ -544,7 +529,8 @@ class ComputeVerifier final : protected tir::ExprVisitor { << "with being Reduce operation or not."; if (reduce && reduce_) { - AssertReduceEqual(reduce, reduce_); + ICHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; } level_ = 0; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index b5a87d9446d8..2eb0693685a6 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -355,12 +355,11 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Array seq_stmt; if (compute_op->body[0]->IsInstance()) { auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { - StructuralEqual eq; - return eq(a->combiner, b->combiner) && // - eq(a->source, b->source) && // - eq(a->axis, b->axis) && // - eq(a->condition, b->condition) && // - eq(a->init, b->init); + return a->combiner.same_as(b->combiner) && // + a->source.same_as(b->source) && // + a->axis.same_as(b->axis) && // + a->condition.same_as(b->condition) && // + ((a->init.empty() && b->init.empty()) || a->init.same_as(b->init)); }; PrimExpr expr_body = compute_op->body[0]; @@ -371,9 +370,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in const tir::ReduceNode* reduce_ = compute_op->body[k].as(); ICHECK(reduce_); ICHECK(f_reducer_equal(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " - << "but the first argument has body " << GetRef(reduce_) << ", while the " << k - << "-th argument has body " << GetRef(reduce); + << "The Reduce inputs of ComputeOp should have the same attribute except value_index"; tensors.push_back(compute_op.output(k)); } diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 774a0f8f1f89..4f5df7ad3024 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -63,17 +63,7 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { } TVM_REGISTER_GLOBAL("te.Placeholder") - .set_body_typed([](Variant> shape_arg, DataType dtype, - std::string name) { - auto shape = [&]() -> Array { - if (auto arg_expr = shape_arg.as()) { - return {arg_expr.value()}; - } else if (auto arg_array = shape_arg.as>()) { - return arg_array.value(); - } else { - LOG(FATAL) << "Variant did not contain either allowed type"; - } - }(); + .set_body_typed([](Array shape, DataType dtype, std::string name) { return placeholder(shape, dtype, name); }); diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 1ad8914e48cc..c38c5a5c800b 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -124,10 +124,9 @@ void ReplaceDataFlow(const Array& stages, std::unordered_mapcombiner, b->combiner) && struct_equal(a->source, b->source) && - struct_equal(a->axis, b->axis) && struct_equal(a->condition, b->condition) && - struct_equal(a->init, b->init); + return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) && + ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); } Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 70e82a605369..3a41c5ac5a25 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -134,7 +134,7 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) { int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true); if (target.defined() && target->kind->name == "hexagon") { - auto value = target->GetAttr("vtcm-capacity").value()->value; + auto value = Downcast(target->attrs.at("vtcm-capacity"))->value; if (value > 0) return value; } return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index c38237a664f7..1506082003fd 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -35,18 +35,6 @@ namespace tvm { namespace tir { -/* \brief Convert an object to a PrimExpr - * - * All conversions to a PrimExpr are performed as part of the FFI, - * when calling a function that accepts a PrimExpr as an argument. If - * a function must normalize to a PrimExpr (e.g. before accessing the - * `expr.dtype` field), this function allows the FFI conversions to be - * explicitly invoked. - */ -TVM_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { - return expr; -}); - #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ Name::Name(PrimExpr a, PrimExpr b, Span span) { \ using T = Name::ContainerType; \ @@ -558,9 +546,7 @@ Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, - Array> args, - Span span) { + .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { Array prim_expr_args; for (const auto& it : args) { ICHECK(it->IsInstance() || it->IsInstance() || @@ -721,11 +707,9 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis if (!init.empty()) { ICHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs"; for (size_t i = 0; i < init.size(); i++) { - ICHECK(init[i].defined()) << "Init value must be defined"; ICHECK(init[i]->IsInstance() || init[i]->IsInstance() || init[i]->IsInstance()) - << "init can only be a IntImm, FloatImm or ProducerLoad, " - << "but received " << init[i] << " of type " << init[i]->GetTypeKey(); + << "init can only be a IntImm, FloatImm or ProducerLoad"; } } n->dtype = source[value_index].dtype(); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 2c94b9d8646b..14dd0eadb65c 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -27,8 +27,6 @@ #include #include -#include "utils.h" - namespace tvm { namespace tir { namespace { @@ -81,11 +79,6 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, if (!ret_type.defined()) { ret_type = VoidType(); } - - if (attrs.defined()) { - attrs = Downcast(NormalizeAttributeObject(attrs)); - } - auto n = make_object(); n->params = std::move(params); n->body = std::move(body); diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 78fb9365cc71..b30d0caf6af3 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -414,7 +414,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx /**************** Implementation ****************/ -PrimFunc Specialize(PrimFunc func, const Map>& param_map) { +PrimFunc Specialize(PrimFunc func, const Map& param_map) { VarMap var_map; for (const auto& kv : param_map) { const Var& param = kv.first; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 9c8f580b5413..5df76450ff1e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -27,7 +27,6 @@ #include #include "buffer_common.h" -#include "utils.h" namespace tvm { namespace tir { @@ -62,15 +61,6 @@ TVM_REGISTER_NODE_TYPE(LetStmtNode); // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { - // The nodes are not required to be a TIR type, and may legally - // contain any ObjectRef. However, normalizing to an IR type if - // possible prevents spurious discrepancies in StructuralEqual(). - if (auto opt = node.as()) { - node = Bool(opt.value()); - } else if (auto opt = node.as()) { - node = Integer(opt.value()); - } - auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); @@ -119,21 +109,13 @@ TVM_REGISTER_GLOBAL("tir.AssertStmt") // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional thread_binding, Map annotations, Span span) { - ICHECK(loop_var.defined()); ICHECK(min.defined()); ICHECK(extent.defined()); + ICHECK(min.dtype().is_scalar()); + ICHECK(extent.dtype().is_scalar()); + ICHECK(loop_var.dtype().is_scalar()); ICHECK(body.defined()); - auto require_scalar_int_dtype = [&](PrimExpr expr, const char* field_name) { - auto dtype = expr.dtype(); - CHECK(dtype.is_scalar() && (dtype.is_int() || dtype.is_uint())) - << "TIR For nodes require a scalar integer as the " << field_name << ", but received " - << expr << " with dtype " << dtype; - }; - require_scalar_int_dtype(loop_var, "loop_var"); - require_scalar_int_dtype(min, "min"); - require_scalar_int_dtype(extent, "extent"); - // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them // without raising errors. auto try_promote_imm_dtype = [&](const PrimExpr& e) { @@ -154,8 +136,6 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); - annotations = Downcast>(NormalizeAttributeObject(annotations)); - ObjectPtr node = make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); @@ -254,8 +234,6 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim ICHECK(condition.defined()); ICHECK(condition.dtype().is_bool()); - annotations = Downcast>(NormalizeAttributeObject(annotations)); - ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; @@ -310,8 +288,6 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array ext ICHECK(body.defined()); ICHECK(data_or_idx.defined()); - annotations = Downcast>(NormalizeAttributeObject(annotations)); - ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; @@ -676,8 +652,6 @@ Block::Block(Array iter_vars, Array reads, Array init, Array alloc_buffers, Array match_buffers, Map annotations, Span span) { - annotations = Downcast>(NormalizeAttributeObject(annotations)); - ObjectPtr node = make_object(); node->iter_vars = std::move(iter_vars); node->reads = std::move(reads); diff --git a/src/tir/ir/utils.cc b/src/tir/ir/utils.cc deleted file mode 100644 index 0e3dc1237894..000000000000 --- a/src/tir/ir/utils.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/tir/ir/utils.cc - * \brief Utilities for manipulating TIR - */ -#include "utils.h" - -#include - -namespace tvm { -namespace tir { - -ObjectRef NormalizeAttributeObject(ObjectRef obj) { - if (const auto* runtime_int = obj.as()) { - return Integer(runtime_int->value); - } else if (const auto* runtime_bool = obj.as()) { - return Bool(runtime_bool->value); - } else if (const auto* runtime_float = obj.as()) { - return FloatImm(DataType::Float(32), runtime_float->value); - } else if (auto opt_array = obj.as>()) { - return opt_array.value().Map(NormalizeAttributeObject); - } else if (auto opt_map = obj.as>()) { - Map new_map; - bool is_same = true; - - for (const auto& [key, obj] : opt_map.value()) { - ObjectRef new_obj = NormalizeAttributeObject(obj); - is_same = is_same && obj.same_as(new_obj); - new_map.Set(key, new_obj); - } - - if (is_same) { - return obj; - } else { - return new_map; - } - } else if (auto dict_attrs = obj.as()) { - auto new_attrs = Downcast>(NormalizeAttributeObject(dict_attrs->dict)); - if (new_attrs.same_as(dict_attrs->dict)) { - return GetRef(dict_attrs); - } else { - return DictAttrs(new_attrs); - } - } else { - return obj; - } -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/ir/utils.h b/src/tir/ir/utils.h deleted file mode 100644 index b1f7a722899f..000000000000 --- a/src/tir/ir/utils.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tir/ir/utils.h - * \brief Utilities for manipulating TIR - */ -#ifndef TVM_TIR_IR_UTILS_H_ -#define TVM_TIR_IR_UTILS_H_ - -#include - -namespace tvm { -namespace tir { - -/* \brief Normalize an ObjectRef held - * - * Where possible, the IR should be normalized contain IR types. For - * example, holding a `tir::IntImm` instead of a `runtime::Int`. In - * attributes, this is not always possible, as attributes may refer to - * non-IR objects. - * - * This function normalizes any `runtime::Int`, `runtime::Bool`, - * `runtime::Float`, or containers of those types to the corresponding - * IR type. - * - * \param obj The attribute object to be normalized - * - * \returns The normalized attribute - */ -ObjectRef NormalizeAttributeObject(ObjectRef obj); - -} // namespace tir -} // namespace tvm -#endif // TVM_TIR_IR_UTILS_H_ diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index dad4ea98d614..c79a148e4b6e 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -229,12 +229,9 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } PrimExpr ret(PrimExpr value, Span span) { - CHECK(value.defined()); return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } -TVM_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); - // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { using namespace tir; @@ -1051,15 +1048,12 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { - if (auto opt = args[0].TryAsInt()) { - *ret = tir::make_const(args[1], opt.value(), args[2]); - } else if (auto opt = args[0].TryAsBool()) { - *ret = tir::make_const(args[1], opt.value(), args[2]); - } else if (auto opt = args[0].TryAsFloat()) { - *ret = tir::make_const(args[1], opt.value(), args[2]); + if (args[0].type_code() == kDLInt) { + *ret = tir::make_const(args[1], args[0].operator int64_t(), args[2]); + } else if (args[0].type_code() == kDLFloat) { + *ret = tir::make_const(args[1], args[0].operator double(), args[2]); } else { - LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or bool, " - << "but instead received argument with type code " << args[0].type_code(); // FIXME + LOG(FATAL) << "only accept int or float"; // FIXME } }); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 73b5ff3fafd4..cda501cd992e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -233,9 +233,9 @@ support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } -ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); @@ -914,14 +914,6 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_ if (ann_val.as()) { return ann_val; } - if (auto* runtime_int = ann_val.as()) { - return IntImm(DataType::Int(32), runtime_int->value); - } else if (auto* runtime_float = ann_val.as()) { - return FloatImm(DataType::Float(32), runtime_float->value); - } else if (auto* runtime_bool = ann_val.as()) { - return Bool(runtime_bool->value); - } - if (const auto* expr = ann_val.as()) { ICHECK(!ann_val->IsInstance()) << "TypeError: runtime::String is expected, but gets StringImm"; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 092bcf0c79f9..4eccff10a2c7 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -87,9 +87,8 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision = NullOpt) override; + ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 9209e6578687..122c5ff0d9fe 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -439,11 +439,6 @@ inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os } else if (const auto* float_imm = obj.as()) { os.precision(17); os << float_imm->value; - } else if (const auto* runtime_int = obj.as()) { - os << runtime_int->value; - } else if (const auto* runtime_float = obj.as()) { - os.precision(17); - os << runtime_float->value; } else if (const auto* array = obj.as()) { os << '['; bool is_first = true; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index fd1349e4a3ec..fe1c1850dcd5 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -55,9 +55,8 @@ std::vector SampleWithoutReplacement( * \return The random variable sampled from candidates */ TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, - const Array& probs, - Optional* decision); + const Array& candidates, const Array& probs, + Optional* decision); /*! * \brief Create a sampling function that does multinomial sampling. * \param rand_state The random state. diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index 4c7b208e964f..92c3423bcbbb 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include "../../ir/utils.h" #include "../utils.h" namespace tvm { @@ -98,8 +97,6 @@ struct AnnotateTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val, String ann_key) { - ann_val = NormalizeAttributeObject(ann_val); - if (auto block = block_or_loop_rv.as()) { return sch->Annotate(block.value(), ann_key, ann_val); } diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 8e16f50b8b95..2a2f17355ca6 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -163,18 +163,19 @@ std::vector SampleWithoutReplacement( } int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision) { + const Array& candidates, const Array& probs, + Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; int32_t i = -1; int32_t n = candidates.size(); if (decision->defined()) { - i = decision->value()->value; + const auto* int_imm = decision->as(); + i = int_imm->value; CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - std::vector weights = support::AsVector(probs); + std::vector weights = support::AsVector(probs); std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); @@ -182,8 +183,8 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st << ", but decision is: " << i; } - *decision = runtime::Int(i); // decision is guaranteed not to be nullptr. - return candidates[i]->value; + *decision = Integer(i); // decision is guaranteed not to be nullptr. + return candidates[i].IntValue(); } std::function MakeMultinomialSampler( @@ -460,11 +461,24 @@ struct SampleCategoricalTraits : public UnpackedInstTraits candidates, // - Array probs, // - Optional decision) { - return sch->SampleCategorical(candidates, probs, decision); + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + Array candidates, // + Array probs, // + Optional decision) { + Array probs_float = probs.Map([](const ObjectRef& prob) { + const auto* prob_float = prob.as(); + if (prob_float != nullptr) { + return GetRef(prob_float); + } + const auto* prob_int = prob.as(); + if (prob_int != nullptr) { + return FloatImm(DataType::Float(32), static_cast(prob_int->value)); + } + LOG(FATAL) + << "SampleCategorical does not accept probability with type other than float or int."; + throw; + }); + return sch->SampleCategorical(candidates, probs_float, decision); } static String UnpackedAsPython(Array outputs, // diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 6e243bf19198..4b10df7e9728 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -112,9 +112,7 @@ Array TranslateInputRVs( } else if (const auto* str_obj = input.as()) { // Case 2. string => "content" results.push_back(String('"' + std::string(str_obj->data) + '"')); - } else if (input->IsInstance() || input->IsInstance() || - input->IsInstance() || - input->IsInstance()) { + } else if (input->IsInstance() || input->IsInstance()) { // Case 3. integer or floating-point number results.push_back(input); } else if (input->IsInstance()) { @@ -151,9 +149,7 @@ Array TranslateInputRVs(const Array& inputs, results.reserve(inputs.size()); for (const ObjectRef& input : inputs) { // Case 3. integer or floating-point number - if (input->IsInstance() || input->IsInstance() || - input->IsInstance() || - input->IsInstance()) { + if (input->IsInstance() || input->IsInstance()) { results.push_back(input); continue; } @@ -392,9 +388,9 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { try { const ArrayNode* arr = decision_entry.as(); ICHECK(arr && arr->size() == 2); - auto arr0 = arr->at(0).as(); + const IntImmNode* arr0 = arr->at(0).as(); ICHECK(arr0); - index = arr0.value(); + index = arr0->value; decision = arr->at(1); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, " diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 1611109d7735..16c4350aaee6 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -53,9 +53,9 @@ Schedule TracedScheduleNode::Copy() { /******** Schedule: Sampling ********/ -ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { ExprRV result = CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 78629e84f039..686d84ebc6fe 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,9 +47,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision = NullOpt) final; + ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) final; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index 14672f568549..cc33ba9f86c2 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -231,7 +231,7 @@ class PrimFuncInliner : StmtExprMutator { << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; - Map> param_map; + Map param_map; for (size_t i = 0; i < callee->params.size(); i++) { param_map.Set(callee->params[i], args[i]); } diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 2948773321dd..423b0ca92237 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -155,7 +155,6 @@ inline DataType APIType(DataType t) { ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; if (t.is_handle()) return t; ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; - if (t.is_bool()) return DataType::Bool(); if (t.is_uint() || t.is_int()) return DataType::Int(64); ICHECK(t.is_float()); return DataType::Float(64); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 1cde4f2ebe7d..1a3888a7cd48 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -511,8 +511,6 @@ class BuiltinLower : public StmtExprMutator { arg_tcode = kTVMStr; } else if (IsArrayHandle(arg)) { arg_tcode = kTVMDLTensorHandle; - } else if (arg.dtype().is_bool()) { - arg_tcode = kTVMArgBool; } // opaque handle need to set the kind properly if (arg_tcode == kTVMOpaqueHandle) { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 9f2f1295fece..d327cdfa8393 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -263,15 +263,15 @@ PrimFunc MakePackedAPI(PrimFunc func) { // --------------------------- // local function definitions // load i-th argument as type t - auto f_arg_value = [&](DataType arg_type, int i) { + auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, IntImm(DataType::Int(32), i), IntImm(DataType::Int(32), builtin::kTVMValueContent)}; // load 64 bit version - DataType api_type = APIType(arg_type); + DataType api_type = APIType(t); PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); // cast to the target version. - if (api_type != arg_type) { - res = Cast(arg_type, res); + if (api_type != t) { + res = Cast(t, res); } return res; }; @@ -319,7 +319,10 @@ PrimFunc MakePackedAPI(PrimFunc func) { continue; } - PrimExpr arg_value; + var_def.emplace_back(f_arg_value(param.dtype(), i), param); + if (func_ptr->buffer_map.count(param)) { + buffer_def.emplace_back(param, func_ptr->buffer_map[param]); + } // type code checks Var tcode(param->name_hint + ".code", DataType::Int(32)); @@ -332,45 +335,15 @@ PrimFunc MakePackedAPI(PrimFunc func) { seq_init.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, tvm::tir::StringImm(msg.str()), nop)); - - arg_value = f_arg_value(param.dtype(), i); - } else if (t.is_bool()) { - std::ostringstream msg; - msg << name_hint << ": Expect arg[" << i << "] to be boolean"; - seq_init.emplace_back( - AssertStmt(tcode == kTVMArgBool || tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); - - arg_value = Call(t, builtin::if_then_else(), - { - tcode == kTVMArgBool, - f_arg_value(DataType::Bool(), i), - cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)), - }); - } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_init.emplace_back( - AssertStmt(tcode == kDLInt || tcode == kTVMArgBool, tvm::tir::StringImm(msg.str()), nop)); - - arg_value = Call(t, builtin::if_then_else(), - { - tcode == kTVMArgInt, - f_arg_value(t, i), - cast(t, f_arg_value(DataType::Bool(), i)), - }); + seq_init.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); } else { ICHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; seq_init.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); - - arg_value = f_arg_value(param.dtype(), i); - } - - var_def.emplace_back(arg_value, param); - if (func_ptr->buffer_map.count(param)) { - buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } } diff --git a/tests/cpp/relay/backend/runtime_test.cc b/tests/cpp/relay/backend/runtime_test.cc index adabb9b9b6cf..53ea7e39ed59 100644 --- a/tests/cpp/relay/backend/runtime_test.cc +++ b/tests/cpp/relay/backend/runtime_test.cc @@ -26,13 +26,13 @@ namespace tvm { namespace relay { TVM_REGISTER_RUNTIME("TestRuntime") - .add_attr_option("my_bool") + .add_attr_option("my_bool") .add_attr_option>("your_names") .add_attr_option("another_option") - .add_attr_option("defaulty_the_default_option", runtime::Bool(false)); + .add_attr_option("defaulty_the_default_option", Bool(false)); TEST(Runtime, Create) { - Map attrs = {{"my_bool", runtime::Bool(true)}}; + Map attrs = {{"my_bool", Bool(true)}}; Runtime my_runtime = Runtime::Create("TestRuntime", attrs); ASSERT_EQ(my_runtime->GetAttr("my_bool"), true); ASSERT_EQ(my_runtime->GetAttr>("your_names").defined(), false); @@ -40,7 +40,7 @@ TEST(Runtime, Create) { } TEST(Runtime, UnknownAttr) { - Map attrs = {{"woofles", runtime::Bool(true)}}; + Map attrs = {{"woofles", Bool(true)}}; ASSERT_THROW(Runtime::Create("TestRuntime", attrs), Error); } @@ -64,7 +64,7 @@ TEST(RuntimeRegistry, ListRuntimeOptions) { Map attrs = Runtime::ListRuntimeOptions("TestRuntime"); ICHECK_EQ(attrs.empty(), false); - ICHECK_EQ(attrs["my_bool"], "runtime.BoxBool"); + ICHECK_EQ(attrs["my_bool"], "IntImm"); ICHECK_EQ(attrs["your_names"], "Array"); ICHECK_EQ(attrs["another_option"], "runtime.String"); } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 0a2b8206d322..2db4b572bf60 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -32,15 +32,15 @@ using namespace tvm; TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU) .set_attr("Attr1", "Value1") - .add_attr_option("my_bool") + .add_attr_option("my_bool") .add_attr_option>("your_names") - .add_attr_option>("her_maps"); + .add_attr_option>("her_maps"); TargetJSON TestTargetParser(TargetJSON target) { String mcpu = Downcast(target.at("mcpu")); target.Set("mcpu", String("super_") + mcpu); target.Set("keys", Array({"super"})); - target.Set("features", Map{{"test", runtime::Bool(true)}}); + target.Set("features", Map{{"test", Bool(true)}}); return target; } @@ -76,14 +76,14 @@ TEST(TargetKind, GetAttrMap) { TEST(TargetCreation, NestedConfig) { Map config = { - {"my_bool", runtime::Bool(true)}, + {"my_bool", Bool(true)}, {"your_names", Array{"junru", "jian"}}, {"kind", String("TestTargetKind")}, { "her_maps", - Map{ - {"a", runtime::Int(1)}, - {"b", runtime::Int(2)}, + Map{ + {"a", 1}, + {"b", 2}, }, }, }; @@ -91,14 +91,13 @@ TEST(TargetCreation, NestedConfig) { ICHECK_EQ(target->kind, TargetKind::Get("TestTargetKind").value()); ICHECK_EQ(target->tag, ""); ICHECK(target->keys.empty()); - runtime::Bool my_bool = target->GetAttr("my_bool").value(); + Bool my_bool = target->GetAttr("my_bool").value(); ICHECK_EQ(my_bool.operator bool(), true); Array your_names = target->GetAttr>("your_names").value(); ICHECK_EQ(your_names.size(), 2U); ICHECK_EQ(your_names[0], "junru"); ICHECK_EQ(your_names[1], "jian"); - Map her_maps = - target->GetAttr>("her_maps").value(); + Map her_maps = target->GetAttr>("her_maps").value(); ICHECK_EQ(her_maps.size(), 2U); ICHECK_EQ(her_maps["a"], 1); ICHECK_EQ(her_maps["b"], 2); @@ -106,15 +105,15 @@ TEST(TargetCreation, NestedConfig) { TEST(TargetCreationFail, UnrecognizedConfigOption) { Map config = { - {"my_bool", runtime::Bool(true)}, + {"my_bool", Bool(true)}, {"your_names", Array{"junru", "jian"}}, {"kind", String("TestTargetKind")}, {"bad", ObjectRef(nullptr)}, { "her_maps", - Map{ - {"a", runtime::Int(1)}, - {"b", runtime::Int(2)}, + Map{ + {"a", 1}, + {"b", 2}, }, }, }; @@ -134,9 +133,9 @@ TEST(TargetCreationFail, TypeMismatch) { {"kind", String("TestTargetKind")}, { "her_maps", - Map{ - {"a", runtime::Int(1)}, - {"b", runtime::Int(2)}, + Map{ + {"a", 1}, + {"b", 2}, }, }, }; @@ -151,13 +150,13 @@ TEST(TargetCreationFail, TypeMismatch) { TEST(TargetCreationFail, TargetKindNotFound) { Map config = { - {"my_bool", runtime::Bool("true")}, + {"my_bool", Bool("true")}, {"your_names", Array{"junru", "jian"}}, { "her_maps", - Map{ - {"a", runtime::Int(1)}, - {"b", runtime::Int(2)}, + Map{ + {"a", 1}, + {"b", 2}, }, }, }; @@ -179,16 +178,15 @@ TEST(TargetCreation, TargetParser) { TEST(TargetCreation, TargetFeatures) { Target test_target_with_parser("TestTargetParser -mcpu=woof"); - ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); + ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); Target test_target_no_parser("TestTargetKind"); - ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); - ASSERT_EQ(test_target_no_parser->GetFeature("test", runtime::Bool(true)).value(), - true); + ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); + ASSERT_EQ(test_target_no_parser->GetFeature("test", Bool(true)).value(), true); } TEST(TargetCreation, TargetFeaturesBeforeParser) { - Map features = {{"test", runtime::Bool(true)}}; + Map features = {{"test", Bool(true)}}; Map config = { {"kind", String("TestTargetParser")}, {"mcpu", String("woof")}, @@ -471,13 +469,13 @@ TEST(TargetCreation, DetectSystemTriple) { #endif TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_1", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal) - .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU) .set_attr(tvm::attr::kRelayToTIR, diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py index f5b1651e115a..bbfb8bd2db12 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py +++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py @@ -15,14 +15,10 @@ # specific language governing permissions and limitations # under the License. """Test packed function FFI.""" -import gc - -import numpy as np - import tvm from tvm import te import tvm.testing -from tvm.script import tir as T +import numpy as np def test_get_global(): @@ -41,7 +37,7 @@ def my_packed_func(*args): def test_get_callback_with_node(): - x = T.int32(10) + x = tvm.runtime.convert(10) def test(y): assert y.handle != x.handle @@ -70,7 +66,7 @@ def add(x): myf = tvm.runtime.convert(addy) f = myf(10) - assert f(11) == 21 + assert f(11).value == 21 def test_convert(): @@ -117,14 +113,6 @@ def test_device_func(dev): def test_rvalue_ref(): def callback(x, expected_count): - # The use count of TVM objects is decremented as part of - # `ObjectRef.__del__`, which runs when the Python object is - # destructed. However, Python object destruction is not - # deterministic, and even CPython's reference-counting is - # considered an implementation detail. Therefore, to ensure - # correct results from this test, `gc.collect()` must be - # explicitly called. - gc.collect() assert expected_count == tvm.testing.object_use_count(x) return x diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index 42f5b0ccd0b8..afd716cde389 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -16,27 +16,16 @@ # under the License. import tvm import tvm.testing -from tvm import te, tir -from tvm.script import tir as T +from tvm import te class CanonicalChecker: def __init__(self): self.analyzer = tvm.arith.Analyzer() - def _convert(self, expr): - # TODO(Lunderberg): Make utility functions `tir.convert` and - # `relax.convert` that convert to their respective IR types. - # Implementation should be in C++, and should only consist of - # conversions that are applied automatically through FFI. - if isinstance(expr, int): - return T.int32(expr) - else: - return expr - def verify(self, data, expected): res = self.analyzer.canonical_simplify(data) - expected = self._convert(expected) + expected = tvm.runtime.convert(expected) assert tvm.ir.structural_equal(res, expected), "\ndata={}\nres={}\nexpected={}".format( data, res, expected ) @@ -388,13 +377,13 @@ def test_simplify_normalize_min_value_expr(): x = te.var("x", "int32") ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32")) - ck.verify(te.min_value("int32") + x == 0, tir.const(False)) + ck.verify(te.min_value("int32") + x == 0, False) ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32")) - ck.verify(0 == te.min_value("int32") + x, tir.const(False)) + ck.verify(0 == te.min_value("int32") + x, False) ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32")) - ck.verify(x + te.min_value("int32") == 0, tir.const(False)) + ck.verify(x + te.min_value("int32") == 0, False) ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32")) - ck.verify(0 == x + te.min_value("int32"), tir.const(False)) + ck.verify(0 == x + te.min_value("int32"), False) def test_proddiv_simplify(): diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py index f0e6f05adfad..3a10ec05efeb 100644 --- a/tests/python/arith/test_arith_iter_affine_map.py +++ b/tests/python/arith/test_arith_iter_affine_map.py @@ -17,7 +17,6 @@ import tvm import tvm.testing from tvm.tir import floordiv, floormod -from tvm.script import tir as T def ifuse(inputs, pred_extent=None): @@ -538,7 +537,7 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) tvm.ir.assert_structural_equal(res[0][1], x + c) tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18) - tvm.ir.assert_structural_equal(res[1][1], T.bool(True)) + tvm.ir.assert_structural_equal(res[1][1], True) # compound 1 i0 = create_iter("i0", 4) @@ -554,7 +553,7 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[0][1], 0) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) @@ -570,7 +569,7 @@ def test_subspace_division(): assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices @@ -588,11 +587,11 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[0][1], 0) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) - tvm.ir.assert_structural_equal(res[2][1], T.bool(True)) + tvm.ir.assert_structural_equal(res[2][1], True) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 @@ -607,9 +606,9 @@ def test_subspace_division(): assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) - tvm.ir.assert_structural_equal(res[2][0], T.bool(True)) + tvm.ir.assert_structural_equal(res[2][0], True) tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices @@ -643,10 +642,10 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) - tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) - tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) - tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices @@ -662,9 +661,9 @@ def test_subspace_division(): assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], j0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) - tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) - tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map( @@ -691,10 +690,10 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) - tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) - tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) - tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) @@ -736,8 +735,8 @@ def test_subspace_divide_trivial_iters(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], x) - tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) - tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) + tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], y) diff --git a/tests/python/arith/test_arith_narrow_predicate_expression.py b/tests/python/arith/test_arith_narrow_predicate_expression.py index 0aa353c60041..d38fe70f6b5c 100644 --- a/tests/python/arith/test_arith_narrow_predicate_expression.py +++ b/tests/python/arith/test_arith_narrow_predicate_expression.py @@ -20,7 +20,6 @@ from tvm import tir from tvm.runtime import convert -from tvm.script import tir as T i = tir.Var("i", "int32") @@ -43,18 +42,18 @@ [i < n, i < 0], [i <= n, i <= 0], [i >= n, i >= 7], - [n > i, T.int32(0) > i], - [n < i, T.int32(7) < i], - [n <= i, T.int32(7) <= i], - [n >= i, T.int32(0) >= i], - [i == n, tir.all(i <= 0, T.int32(7) <= i)], - [n == i, tir.all(T.int32(7) <= i, i <= 0)], - [i != n, tir.any(i < 0, T.int32(7) < i)], - [n != i, tir.any(T.int32(7) < i, i < 0)], + [n > i, convert(0) > i], + [n < i, convert(7) < i], + [n <= i, convert(7) <= i], + [n >= i, convert(0) >= i], + [i == n, tir.all(i <= 0, convert(7) <= i)], + [n == i, tir.all(convert(7) <= i, i <= 0)], + [i != n, tir.any(i < 0, convert(7) < i)], + [n != i, tir.any(convert(7) < i, i < 0)], [i // 4 > n, i // 4 > 7], - [n < i // 4, T.int32(7) < i // 4], + [n < i // 4, convert(7) < i // 4], [(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0], - [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, T.int32(0) <= tir.Add(i, 0) // 4)], + [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)], [i + n < 10, i + 7 < 10], [i - n < 10, tir.Sub(i, 0) < 10], [tir.Not(i < n), tir.Not(i < 7)], diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 7fc1862192d6..90f0aeef47d7 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -27,8 +27,6 @@ from tvm.tir import truncdiv as tdiv from tvm.tir import truncmod as tmod -from tvm.script import tir as T - class TestCase: def __init__(self, before, expected, preconditions=None): @@ -37,21 +35,10 @@ def __init__(self, before, expected, preconditions=None): if isinstance(expected, tir.expr.EqualOp): expected = expected.asobject() - self.before = self._convert(before) - self.expected = self._convert(expected) + self.before = before + self.expected = expected self.preconditions = preconditions - @staticmethod - def _convert(expr): - if isinstance(expr, tir.expr.EqualOp): - return expr.asobject() - elif isinstance(expr, int): - return T.int32(expr) - elif isinstance(expr, float): - return T.float32(expr) - else: - return expr - @property def constraint(self): if self.preconditions is None: @@ -1021,8 +1008,8 @@ class TestComparisons(BaseCompare): TestCase(tir.all(fld(x, 8) == -3, flm(x, 8) == 4), x == -20), TestCase(tir.all(flm(x, 8) == 4, fld(x, 8) == -3), x == -20), # Rewrite based on definition of integer division - TestCase(tir.all(T.int32(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), - TestCase(tir.all(x - y * 5 < 5, T.int32(0) <= x - y * 5), y == fld(x, 5)), + TestCase(tir.all(tvm.runtime.convert(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), + TestCase(tir.all(x - y * 5 < 5, tvm.runtime.convert(0) <= x - y * 5), y == fld(x, 5)), # Narrow upper bound using floormod TestCase(tir.all(x < 20, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), TestCase(tir.all(x < 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), @@ -1038,36 +1025,36 @@ class TestComparisons(BaseCompare): # Merge a known floordiv and an upper bound of floormod into a value range TestCase( tir.all(fld(x, 10) == 5, flm(x, 10) < 7), - tir.all(T.int32(50) <= x, x < 57), + tir.all(tvm.runtime.convert(50) <= x, x < 57), ), TestCase( tir.all(fld(x, 10) == 5, flm(x, 10) <= 7), - tir.all(T.int32(50) <= x, x <= 57), + tir.all(tvm.runtime.convert(50) <= x, x <= 57), ), TestCase( tir.all(fld(x, 10) == -5, flm(x, 10) < 7), - tir.all(T.int32(-50) <= x, x < -43), + tir.all(tvm.runtime.convert(-50) <= x, x < -43), ), TestCase( tir.all(fld(x, 10) == -5, flm(x, 10) <= 7), - tir.all(T.int32(-50) <= x, x <= -43), + tir.all(tvm.runtime.convert(-50) <= x, x <= -43), ), # Merge a known floordiv and an lower bound of floormod into a value range TestCase( - tir.all(fld(x, 10) == 5, T.int32(7) < flm(x, 10)), - tir.all(T.int32(57) < x, x < 60), + tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) < flm(x, 10)), + tir.all(tvm.runtime.convert(57) < x, x < 60), ), TestCase( - tir.all(fld(x, 10) == 5, T.int32(7) <= flm(x, 10)), - tir.all(T.int32(57) <= x, x < 60), + tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) <= flm(x, 10)), + tir.all(tvm.runtime.convert(57) <= x, x < 60), ), TestCase( - tir.all(fld(x, 10) == -5, T.int32(7) < flm(x, 10)), - tir.all(T.int32(-43) < x, x < -40), + tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) < flm(x, 10)), + tir.all(tvm.runtime.convert(-43) < x, x < -40), ), TestCase( - tir.all(fld(x, 10) == -5, T.int32(7) <= flm(x, 10)), - tir.all(T.int32(-43) <= x, x < -40), + tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) <= flm(x, 10)), + tir.all(tvm.runtime.convert(-43) <= x, x < -40), ), TestCase(tvm.te.min(x, 11) < 10, x < 10), TestCase(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool")), @@ -1237,16 +1224,14 @@ class TestIfThenElse(BaseCompare): class TestCLZ(BaseCompare): test_case = tvm.testing.parameter( - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), T.int32(32)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), T.int32(31)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), T.int32(30)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), T.int32(24)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), T.int32(64)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), T.int32(63)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), T.int32(62)), - TestCase( - tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), T.int32(56) - ), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), 32), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), 31), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), 30), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), 24), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), 64), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), 63), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), 62), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), 56), ) diff --git a/tests/python/arith/test_arith_solve_linear_equations.py b/tests/python/arith/test_arith_solve_linear_equations.py index 3195a4ae514f..24eb860c55f6 100644 --- a/tests/python/arith/test_arith_solve_linear_equations.py +++ b/tests/python/arith/test_arith_solve_linear_equations.py @@ -19,7 +19,6 @@ import pytest import tvm from tvm import te, arith, ir, tir, testing -from tvm.script import tir as T def test_solution_consistency(): @@ -110,8 +109,8 @@ def test_unique_solution(): [x, y], ) assert list(solution.dst.variables) == [] - assert ir.structural_equal(solution.src_to_dst[x], T.int32(15)) - assert ir.structural_equal(solution.src_to_dst[y], T.int32(5)) + assert ir.structural_equal(solution.src_to_dst[x], 15) + assert ir.structural_equal(solution.src_to_dst[y], 5) def test_low_rank(): @@ -129,7 +128,7 @@ def test_low_rank(): [n0] = solution.dst.variables assert ir.structural_equal(solution.src_to_dst[x], n0 + 10) assert ir.structural_equal(solution.src_to_dst[y], -n0) - assert ir.structural_equal(solution.src_to_dst[z], T.int32(5)) + assert ir.structural_equal(solution.src_to_dst[z], 5) def test_infer_range(): @@ -150,12 +149,12 @@ def test_infer_range(): assert ir.structural_equal(solution.src_to_dst[x], n0) assert ir.structural_equal(solution.src_to_dst[y], -n0) # inferred from y's range - assert ir.structural_equal(solution.dst.ranges[n0].min, T.int32(-9)) - assert ir.structural_equal(solution.dst.ranges[n0].extent, T.int32(10)) + assert ir.structural_equal(solution.dst.ranges[n0].min, -9) + assert ir.structural_equal(solution.dst.ranges[n0].extent, 10) # additional inequality is added into the system for x [ineq] = solution.dst.relations assert isinstance(ineq, tvm.tir.LE) - assert ir.structural_equal(ineq.a, T.int32(-5)) + assert ir.structural_equal(ineq.a, -5) assert ir.structural_equal(ineq.b, n0) @@ -173,7 +172,7 @@ def test_ill_formed(): ) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - ir.assert_structural_equal(rel, tir.const(False)) + assert ir.structural_equal(rel, False) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/arith/test_arith_solve_linear_inequality.py b/tests/python/arith/test_arith_solve_linear_inequality.py index 664258ae7cf1..5285da12e75d 100644 --- a/tests/python/arith/test_arith_solve_linear_inequality.py +++ b/tests/python/arith/test_arith_solve_linear_inequality.py @@ -19,7 +19,6 @@ import pytest import tvm from tvm import te, arith, ir, tir, testing -from tvm.script import tir as T @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11458") @@ -114,10 +113,10 @@ def test_dual_variable(): [x_new, y_new] = solution.dst.variables [rel] = solution.dst.relations assert ir.structural_equal(rel, (y_new * 2) + x_new <= 10) - assert ir.structural_equal(solution.dst.ranges[x_new].min, T.int32(0)) - assert ir.structural_equal(solution.dst.ranges[x_new].extent, T.int32(11)) - assert ir.structural_equal(solution.dst.ranges[y_new].min, T.int32(0)) - assert ir.structural_equal(solution.dst.ranges[y_new].extent, T.int32(6)) + assert ir.structural_equal(solution.dst.ranges[x_new].min, 0) + assert ir.structural_equal(solution.dst.ranges[x_new].extent, 11) + assert ir.structural_equal(solution.dst.ranges[y_new].min, 0) + assert ir.structural_equal(solution.dst.ranges[y_new].extent, 6) assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10)) assert ir.structural_equal(solution.src_to_dst[y], y_new) assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) @@ -186,7 +185,7 @@ def test_no_solution(): solution = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - ir.assert_structural_equal(rel, tir.const(False)) + assert ir.structural_equal(rel, False) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 112d1151febd..112c521d06d4 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -769,7 +769,7 @@ def check_cuda(dtype, n, l, padding, lanes): (n // lanes, l + 2 * padding, lanes), lambda i, j, k: tvm.te.if_then_else( tvm.te.any(j < padding, j >= l + padding), - tvm.tir.const(0, dtype), + tvm.runtime.convert(0).astype(dtype), A[i * lanes + k, j - padding], ), name="B", diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index d9a6fd6e62d1..f50d63878e4f 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1138,46 +1138,5 @@ def func(): tvm.build(func) -def test_int_parameter(): - """Boolean may be passed to functions accepting int""" - - @T.prim_func - def func(arg: T.int32) -> T.int32: - T.func_attr({"target": T.target("llvm")}) - if arg > 0: - return 10 - else: - return 20 - - built = tvm.build(func) - output = built(True) - assert output == 10 - - output = built(False) - assert output == 20 - - -def test_bool_parameter(): - """Integers may be passed to functions accepting bool""" - - @T.prim_func - def func(arg: T.bool) -> T.int32: - T.func_attr({"target": T.target("llvm")}) - if arg: - return 10 - else: - return 20 - - built = tvm.build(func) - output = built(1) - assert output == 10 - - output = built(2) - assert output == 10 - - output = built(0) - assert output == 20 - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/ir/test_container_structural_equal.py b/tests/python/ir/test_container_structural_equal.py index 238a77b4ef4b..61511c609ca4 100644 --- a/tests/python/ir/test_container_structural_equal.py +++ b/tests/python/ir/test_container_structural_equal.py @@ -56,20 +56,20 @@ def get_first_mismatch_ensure_symmetry(a, b): ( [1, 2, 3], [1, 4, 3], - ObjectPath.root().array_index(1), - ObjectPath.root().array_index(1), + ObjectPath.root().array_index(1).attr("value"), + ObjectPath.root().array_index(1).attr("value"), ), ( [1, 2, 3], [10, 2, 30], - ObjectPath.root().array_index(0), - ObjectPath.root().array_index(0), + ObjectPath.root().array_index(0).attr("value"), + ObjectPath.root().array_index(0).attr("value"), ), ( [1, 3, 4], [1, 2, 3, 4], - ObjectPath.root().array_index(1), - ObjectPath.root().array_index(1), + ObjectPath.root().array_index(1).attr("value"), + ObjectPath.root().array_index(1).attr("value"), ), ( [1, 2, 3], @@ -121,28 +121,14 @@ def test_shape_tuple_structural_equal_to_self(contents): assert get_first_mismatch_ensure_symmetry(a, b) is None -@pytest.mark.parametrize( - "contents", - [ - {}, - {"a": 1, "b": 2}, - {"a": True, "b": False}, - ], -) -def test_string_map_structural_equal_to_self(contents): - a = tvm.runtime.convert({**contents}) - b = tvm.runtime.convert({**contents}) - assert get_first_mismatch_ensure_symmetry(a, b) is None - - @pytest.mark.parametrize( "a, b, expected_a_path, expected_b_path", [ ( dict(a=3, b=4), dict(a=3, b=5), - ObjectPath.root().map_value("b"), - ObjectPath.root().map_value("b"), + ObjectPath.root().map_value("b").attr("value"), + ObjectPath.root().map_value("b").attr("value"), ), ( dict(a=3, b=4), diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index 1e3249197851..aa482dd65cd7 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -23,19 +23,16 @@ def test_array(): a = tvm.runtime.convert([1, 2, 3]) assert len(a) == 3 - assert a[-1] == 3 + assert a[-1].value == 3 a_slice = a[-3:-1] - assert (a_slice[0], a_slice[1]) == (1, 2) + assert (a_slice[0].value, a_slice[1].value) == (1, 2) def test_array_save_load_json(): - a = tvm.runtime.convert([1, 2, 3.5, True]) + a = tvm.runtime.convert([1, 2, 3]) json_str = tvm.ir.save_json(a) a_loaded = tvm.ir.load_json(json_str) - assert a_loaded[1] == 2 - assert a_loaded[2] == 3.5 - assert a_loaded[3] == True - assert isinstance(a_loaded[3], bool) + assert a_loaded[1].value == 2 def test_dir_array(): @@ -69,7 +66,7 @@ def test_str_map(): assert "a" in amap assert len(amap) == 2 dd = dict(amap.items()) - assert amap["a"] == 2 + assert amap["a"].value == 2 assert "a" in dd assert "b" in dd @@ -81,7 +78,7 @@ def test_map_save_load_json(): json_str = tvm.ir.save_json(amap) amap = tvm.ir.load_json(json_str) assert len(amap) == 2 - dd = {kv[0].name: kv[1] for kv in amap.items()} + dd = {kv[0].name: kv[1].value for kv in amap.items()} assert dd == {"a": 2, "b": 3} diff --git a/tests/python/ir/test_ir_type.py b/tests/python/ir/test_ir_type.py index b70406c1bb7a..2355aa19adec 100644 --- a/tests/python/ir/test_ir_type.py +++ b/tests/python/ir/test_ir_type.py @@ -16,7 +16,6 @@ # under the License. """Test type nodes in the IR""" import tvm -from tvm.script import tir as T def check_json_roundtrip(node): @@ -39,9 +38,11 @@ def test_tensor_type_bad_constructor(): def test_tensor_type(): - tt = tvm.ir.TensorType([1, 2, 3], "float32") - assert tt.dtype == "float32" - assert list(tt.shape) == [T.int32(1), T.int32(2), T.int32(3)] + shape = tvm.runtime.convert([1, 2, 3]) + dtype = "float32" + tt = tvm.ir.TensorType(shape, dtype) + assert tt.dtype == dtype + assert tt.shape == shape assert tt.span == None str(tt) check_json_roundtrip(tt) diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index b0ddbe93601e..f1709c449d16 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -40,7 +40,7 @@ def test_constant(): ) assert ( constant.__str__() - == """R.dist.const(1.0, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" + == """R.dist.const(1, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" ) @@ -144,7 +144,7 @@ def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer(( vi, vj = T.axis.remap("SS", [i, j]) T.reads(x[vi, vj]) T.writes(y[vi, vj]) - y[vi, vj] = x[vi, vj] + T.float32(1.0) + y[vi, vj] = x[vi, vj] + T.float32(1) @R.function def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) -> R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"): diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 64d5c7381171..97ad9f5dd034 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -404,7 +404,7 @@ def f( "op": 'ExternFunc(global_symbol="contrib.tensor_array_stack")', "args": '[Var(name_hint="x"), Var(name_hint="y")]', "sinfo_args": "[ObjectStructInfo()]", - "attrs": '{"test_attr": True}', + "attrs": '{"test_attr": 1}', }, extern_call_text, ) diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py index 1efbd690f034..2ab5afaabf24 100644 --- a/tests/python/relax/test_backend_dispatch_sort_scan.py +++ b/tests/python/relax/test_backend_dispatch_sort_scan.py @@ -63,13 +63,6 @@ def foo(x: R.Tensor((2, 3), "float32", "llvm")): def test_dispatch_scanop_cuda(): - """R.cumsum and R.cumprod may be lowered with TOPI for GPU - - For the purpose of testing, this test case intentionally uses the - `exclusive=True` argument to prevent the `R.cumsum` from being - lowered to the packed func `"gpu_2d_continuous_cumsum"`. - """ - @I.ir_module class Before: I.module_global_infos({"vdevice": [I.vdevice("cuda", 0)]}) @@ -77,7 +70,7 @@ class Before: @R.function def main(x: R.Tensor(("m", 3), "float32", "cuda")): with R.dataflow(): - lv0 = R.cumsum(x, axis=1, exclusive=True) + lv0 = R.cumsum(x, axis=1) lv1 = R.cumprod(lv0, axis=1) gv = lv1 R.output(gv) @@ -96,7 +89,6 @@ def main(x: R.Tensor(("m", 3), "float32", "cuda")): topi.cuda.cumsum, x, axis=1, - exclusive=True, ) out = bb.emit_te( topi.cuda.cumprod, diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index e93547d83e3c..7b64eb1dee39 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -395,7 +395,7 @@ def test_call_tir_with_grad(): """ v0: R.Tensor((54, 96), dtype="float32") x = T.int64() -R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": 1.0, "x": x}) +R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": T.float32(1), "x": x}) """, ) @@ -758,7 +758,7 @@ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": True}) + R.func_attr({"relax.force_pure": 1}) R.print(format=R.str("Hi there!")) z: R.Tensor((), dtype="int32") = R.add(x, x) return z @@ -770,7 +770,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function(private=True) def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": True}) + R.func_attr({"relax.force_pure": 1}) R.print(format=R.str("Lol")) z: R.Tensor((), dtype="int32") = R.multiply(x, x) return z diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 30fd06d4f14d..ab40e181a35a 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -566,7 +566,7 @@ def main(shape: R.Prim(value="n")): assert func(2) == 4 - with pytest.raises(TypeError): + with pytest.raises(tvm.TVMError): func(ShapeTuple([2])) diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 60f096585dfe..9a4817f5fd8a 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -118,10 +118,9 @@ class Expected: @T.prim_func def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): T.func_attr({"global_symbol": "__vmtir__ife"}) - if T.Call( + if T.cast( + T.tvm_call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))), "bool", - tvm.ir.Op.get("tir.tvm_call_packed"), - ["vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))], ): T.anylist_setitem_call_packed( r, diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index b79713e05ed3..4031790fc383 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -18,7 +18,6 @@ import numpy as np import tvm -from tvm.script import tir as T from tvm import relay from tvm.relay.build_module import bind_params_by_name from tvm.relay.dataflow_pattern import * @@ -116,7 +115,7 @@ def test_DataTypePattern(): def test_ShapePattern(): - shape = [T.int32(10), T.int32(10)] + shape = [10, 10] pattern = has_shape(shape) assert isinstance(pattern, ShapePattern) tvm.ir.assert_structural_equal(pattern.shape, shape) diff --git a/tests/python/relay/test_executor.py b/tests/python/relay/test_executor.py index 04662f21ae9e..d703ef1f3d9a 100644 --- a/tests/python/relay/test_executor.py +++ b/tests/python/relay/test_executor.py @@ -57,7 +57,7 @@ def test_create_executor_attr_type_incorrect(): with pytest.raises( TVMError, match='Attribute "interface-api" should have type "runtime.String"' - ' but instead found "runtime.BoxBool"', + ' but instead found "IntImm"', ): Executor("aot", {"interface-api": True}) diff --git a/tests/python/relay/test_runtime.py b/tests/python/relay/test_runtime.py index db8252f3a3c4..ea15dd0d3c88 100644 --- a/tests/python/relay/test_runtime.py +++ b/tests/python/relay/test_runtime.py @@ -51,7 +51,7 @@ def test_create_runtime_attr_not_found(): def test_create_runtime_attr_type_incorrect(): with pytest.raises( TVMError, - match='Attribute "system-lib" should have type "runtime.BoxBool"' + match='Attribute "system-lib" should have type "IntImm"' ' but instead found "runtime.String"', ): Runtime("crt", {"system-lib": "woof"}) @@ -65,7 +65,7 @@ def test_list_runtimes(): def test_list_runtime_options(runtime): aot_options = Runtime.list_registered_options(runtime) assert "system-lib" in aot_options - assert aot_options["system-lib"] == "runtime.BoxBool" + assert aot_options["system-lib"] == "IntImm" def test_list_runtime_options_not_found(): diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 7d0cd51d3298..f18994d52ce9 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -18,13 +18,12 @@ for expressions. """ import pytest -import numpy as np - import tvm -from tvm import IRModule, relay -from tvm.relay import op, transform +from tvm import IRModule, parser, relay, te +from tvm.relay import analysis, op, transform from tvm.relay.op import op as _op -from tvm.script import tir as T + +import numpy as np def infer_mod(mod, annotate_spans=True): @@ -555,32 +554,40 @@ def test_repeat_register(): assert "Operator custom_log3 is registered before" in str(cm.execption) -@pytest.mark.parametrize("relay_op", [relay.op.argmax, relay.op.argmin]) -@pytest.mark.parametrize( - "shape_dtype", - [ - ("int32", T.int32), - ("int64", T.int64), - ], - ids=["int32", "int64"], -) -def test_argreduce_infer_return_type(relay_op, shape_dtype): +def test_argreduce_infer_return_type(): x_shape = (1, 1) broadcast_shape = [1, 1] - (sdtype, conv) = shape_dtype - - x = relay.var("data", relay.TensorType(x_shape, "float32")) - broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) - argmax = relay_op(broadcast_to, axis=[1]) - - f = relay.Function([x], argmax) - assert_has_type( - f, - relay.FuncType( - [relay.TensorType(broadcast_shape, "float32")], - relay.TensorType([conv(1)], dtype=sdtype), - ), - ) + shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))] + + # Testing with argmax + for (sdtype, conv) in shape_dtypes: + x = relay.var("data", relay.TensorType(x_shape, "float32")) + broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) + argmax = relay.op.argmax(broadcast_to, axis=[1]) + + f = relay.Function([x], argmax) + assert_has_type( + f, + relay.FuncType( + [relay.TensorType(broadcast_shape, "float32")], + relay.TensorType([conv(1)], dtype=sdtype), + ), + ) + + # Testing with argmin + for (sdtype, conv) in shape_dtypes: + x = relay.var("data", relay.TensorType(x_shape, "float32")) + broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) + argmin = relay.op.argmin(broadcast_to, axis=[1]) + + f = relay.Function([x], argmin) + assert_has_type( + f, + relay.FuncType( + [relay.TensorType(broadcast_shape, "float32")], + relay.TensorType([conv(1)], dtype=sdtype), + ), + ) if __name__ == "__main__": diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py index e0d216b33e9a..7538075ae7f8 100644 --- a/tests/python/runtime/test_runtime_container.py +++ b/tests/python/runtime/test_runtime_container.py @@ -15,13 +15,12 @@ # specific language governing permissions and limitations # under the License. -import pickle -import random - import numpy as np - +import random import tvm import tvm.testing +import pickle +from tvm import te from tvm import nd, relay from tvm.runtime import container as _container @@ -97,123 +96,8 @@ def test_shape_tuple(): assert stuple == z -def test_bool_argument(): - """Boolean objects are currently stored as int""" - func = tvm.get_global_func("testing.AcceptsBool") - - assert isinstance(func(True), bool) - assert isinstance(func(1), bool) - assert isinstance(func(0), bool) - - -def test_int_argument(): - func = tvm.get_global_func("testing.AcceptsInt") - - assert isinstance(func(True), int) - assert isinstance(func(1), int) - assert isinstance(func(0), int) - - -def test_object_ref_argument(): - func = tvm.get_global_func("testing.AcceptsObjectRef") - - assert isinstance(func(True), bool) - assert isinstance(func(1), int) - assert isinstance(func(3.5), float) - assert func(3.5) == 3.5 - - -def test_object_ref_array_argument(): - func = tvm.get_global_func("testing.AcceptsObjectRefArray") - - assert isinstance(func([True, 17, "hello"]), bool) - assert isinstance(func([True]), bool) - assert isinstance(func([17]), int) - assert isinstance(func(["hello"]), str) - - -def test_map_argument_returns_value(): - func = tvm.get_global_func("testing.AcceptsMapReturnsValue") - - res = func({"a": 1, "b": 2}, "a") - assert isinstance(res, int) - assert res == 1 - - res = func({"a": True, "b": False}, "a") - assert isinstance(res, bool) - assert res == True - - -def test_map_argument_returns_map(): - func = tvm.get_global_func("testing.AcceptsMapReturnsMap") - - res = func({"a": 1, "b": 2}) - for key, value in res.items(): - assert isinstance(key, str) - assert isinstance(value, int) - - res = func({"a": False, "b": True}) - for key, value in res.items(): - assert isinstance(key, str) - assert isinstance(value, bool) - - -def test_conversion_of_arg(): - """Arguments may be converted - - The calling side of the FFI converts to types that are available - at runtime. However, there may be additional type conversions - required, that must be performed on the callee-side of the FFI. - """ - - func = tvm.get_global_func("testing.AcceptsPrimExpr") - - res = func(1) - assert isinstance(res, tvm.tir.IntImm) - assert res.dtype == "int32" - - res = func(True) - assert isinstance(res, tvm.tir.IntImm) - assert res.dtype == "bool" - - -def test_conversion_of_array_elements(): - """Elements of an array may require conversion from FFI to param type - - Like `test_conversion_of_arg`, but conversions must be applied - recursively to array elements. Here, the Python-side of the FFI - converts the array `[1,2]` to `Array{runtime::Int(1), - runtime::Int(2)}`, and the C++ side of the FFI converts to - `Array{IntImm(1), IntImm(2)}`. - """ - - func = tvm.get_global_func("testing.AcceptsArrayOfPrimExpr") - - res = func([1, False]) - assert isinstance(res[0], tvm.tir.IntImm) - assert res[0].dtype == "int32" - assert isinstance(res[1], tvm.tir.IntImm) - assert res[1].dtype == "bool" - - -def test_conversion_of_map_values(): - """Elements of a map may require conversion from FFI to param type - - Like `test_conversion_of_arg`, but conversions must be applied - recursively to map elements. Here, the Python-side of the FFI - converts the map `{'a':1, 'b':2}` to `Map{{"a", runtime::Int(1)}, - {"b", runtime::Int(2)}}`, and the C++ side of the FFI converts to - `Map{{"a", IntImm(1)}, {"b", IntImm(2)}}`. - """ - - func = tvm.get_global_func("testing.AcceptsMapOfPrimExpr") - - res = func({"a": 1, "b": False}) - assert isinstance(res["a"], tvm.tir.IntImm) - assert res["a"].dtype == "int32" - assert isinstance(res["b"], tvm.tir.IntImm) - assert res["b"].dtype == "bool" - - if __name__ == "__main__": - tvm.testing.main() + test_string() + test_adt_constructor() + test_tuple_object() + test_shape_tuple() diff --git a/tests/python/te/test_te_schedule_tensorize.py b/tests/python/te/test_te_schedule_tensorize.py index 419d3edb5c3d..79aecb78902a 100644 --- a/tests/python/te/test_te_schedule_tensorize.py +++ b/tests/python/te/test_te_schedule_tensorize.py @@ -16,7 +16,6 @@ # under the License. import tvm from tvm import te -from tvm.script import tir as T def intrin_vadd(xo, m, n): @@ -101,7 +100,6 @@ def add(m): def check(m, factor): x, y, z = add(m) - factor = T.int32(factor) s = te.create_schedule(z.op) xo, xi = s[z].split(z.op.axis[0], factor=factor) vadd = intrin_vadd(xo, m, factor) @@ -135,7 +133,7 @@ def check_cache_write(m, factor): finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[z_global], dom_map) # outer loop var will be rebased, so min value is the new loop var and extent is 1 - tvm.ir.assert_structural_equal(out_dom[xo].extent, T.int32(1)) + tvm.ir.assert_structural_equal(out_dom[xo].extent, 1) assert isinstance(out_dom[xo].min, tvm.tir.Var) assert xo.var.name == out_dom[xo].min.name @@ -185,7 +183,7 @@ def check(factor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) + tvm.ir.assert_structural_equal(out_dom[x].extent, 1) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -209,7 +207,7 @@ def check_rfactor(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) + tvm.ir.assert_structural_equal(out_dom[x].extent, 1) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -232,7 +230,7 @@ def check_rfactor_no_reset(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) + tvm.ir.assert_structural_equal(out_dom[x].extent, 1) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -256,7 +254,7 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) + tvm.ir.assert_structural_equal(out_dom[x].extent, 1) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -266,10 +264,10 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) - check(T.int32(16)) - check_rfactor(T.int32(16), T.int32(16)) - check_rfactor_no_reset(T.int32(16), T.int32(16)) - check_rfactor_no_reset_multi_reduction(T.int32(16), T.int32(16)) + check(16) + check_rfactor(16, 16) + check_rfactor_no_reset(16, 16) + check_rfactor_no_reset_multi_reduction(16, 16) # This tests whether algorithm and intrinsics expressions are simplified diff --git a/tests/python/te/test_te_tag.py b/tests/python/te/test_te_tag.py index a4b76e7d6736..6e88a12614cf 100644 --- a/tests/python/te/test_te_tag.py +++ b/tests/python/te/test_te_tag.py @@ -57,12 +57,12 @@ def test_with(): assert C.op.tag == "gemm" assert "hello" in C.op.attrs assert "xx" not in C.op.attrs - assert C.op.attrs["hello"] == 1 + assert C.op.attrs["hello"].value == 1 CC = tvm.ir.load_json(tvm.ir.save_json(C)) - assert CC.op.attrs["hello"] == 1 - assert len(CC.op.attrs["arr"]) == 2 - assert CC.op.attrs["arr"][0] == 10 - assert CC.op.attrs["arr"][1] == 12 + assert CC.op.attrs["hello"].value == 1 + assert CC.op.attrs["arr"][0].value == 10 + # str format happened to be json compatible + assert json.loads(str(CC.op.attrs))["arr"][1] == 12 def test_decorator(): diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py index 0e610cc1659b..e94a4f09ec56 100644 --- a/tests/python/tir-base/test_lower_build.py +++ b/tests/python/tir-base/test_lower_build.py @@ -122,7 +122,7 @@ def test_lower_build_tir_func(): def test_lower_build_tir_module(): func = matmul.with_attr("global_symbol", "main") - func = func.with_attr("tir.noalias", T.bool(True)) + func = func.with_attr("tir.noalias", True) ir_mod = IRModule({"main": func}) # check lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index d706e65d8186..b4b773197b14 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -14,15 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import pytest import tvm import tvm.testing from tvm import te from tvm.tir import Buffer -from tvm.script import tir as T - import numpy as np -import pytest def test_buffer(): @@ -81,9 +78,9 @@ def test_buffer_access_ptr_extent(): # Test extent from input params aptr = Ab.access_ptr("rw", extent=200) - tvm.ir.assert_structural_equal(aptr.args[3], T.int32(200)) + tvm.ir.assert_structural_equal(aptr.args[3], 200) aptr = Ab.access_ptr("rw", offset=100, extent=100) - tvm.ir.assert_structural_equal(aptr.args[3], T.int32(100)) + tvm.ir.assert_structural_equal(aptr.args[3], 100) def test_buffer_vload(): @@ -91,7 +88,7 @@ def test_buffer_vload(): n = te.size_var("n") Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) - tvm.ir.assert_structural_equal(load.indices, [T.int32(2), T.int32(3)]) + tvm.ir.assert_structural_equal(load.indices, [2, 3]) def test_buffer_offset_of(): @@ -262,7 +259,7 @@ def test_buffer_flatten(): buf = tvm.tir.decl_buffer([16, 32]) flat = buf.get_flattened_buffer() assert buf.data.same_as(flat.data) - tvm.ir.assert_structural_equal(flat.shape, [T.int32(16 * 32)]) + tvm.ir.assert_structural_equal(flat.shape, [16 * 32]) def test_buffer_flatten_preserves_identity(): @@ -276,8 +273,8 @@ def test_buffer_flatten_uses_axis_separators(): """Flattening to N-d physical buffers uses the axis separators""" buf = tvm.tir.decl_buffer([4, 16, 32], axis_separators=[2]) flat = buf.get_flattened_buffer() - tvm.ir.assert_structural_equal(flat.axis_separators, [T.int32(1)]) - tvm.ir.assert_structural_equal(flat.shape, [T.int32(4 * 16), T.int32(32)]) + tvm.ir.assert_structural_equal(flat.axis_separators, [1]) + tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32]) def test_invalid_axis_separators_raises_exception(): diff --git a/tests/python/tir-base/test_tir_index_map.py b/tests/python/tir-base/test_tir_index_map.py index 3ddbd2f69f59..e893ed897d65 100644 --- a/tests/python/tir-base/test_tir_index_map.py +++ b/tests/python/tir-base/test_tir_index_map.py @@ -22,7 +22,6 @@ from tvm.ir import assert_structural_equal from tvm.runtime import const from tvm.tir import IndexMap, IntImm, floordiv, floormod -from tvm.script import tir as T def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: @@ -38,22 +37,28 @@ def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: def test_index_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") - assert_structural_equal(index_map.map_indices([0]), [T.int32(0), T.int32(0)]) - assert_structural_equal(index_map.map_indices([3]), [T.int32(0), T.int32(3)]) - assert_structural_equal(index_map.map_indices([4]), [T.int32(1), T.int32(0)]) - assert_structural_equal(index_map.map_indices([42]), [T.int32(10), T.int32(2)]) - assert_structural_equal(index_map.map_indices([T.int64(42)]), [T.int64(10), T.int64(2)]) + assert_structural_equal(index_map.map_indices([0]), [0, 0]) + assert_structural_equal(index_map.map_indices([3]), [0, 3]) + assert_structural_equal(index_map.map_indices([4]), [1, 0]) + assert_structural_equal(index_map.map_indices([42]), [10, 2]) + assert_structural_equal( + index_map.map_indices([const(42, "int64")]), [const(10, "int64"), const(2, "int64")] + ) def test_shape_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") - assert_structural_equal(index_map.map_shape([4]), [T.int32(1), T.int32(4)]) - assert_structural_equal(index_map.map_shape([16]), [T.int32(4), T.int32(4)]) + assert_structural_equal(index_map.map_shape([4]), [1, 4]) + assert_structural_equal(index_map.map_shape([16]), [4, 4]) - assert_structural_equal(index_map.map_shape([14]), [T.int32(4), T.int32(4)]) - assert_structural_equal(index_map.map_shape([T.int64(16)]), [T.int64(4), T.int64(4)]) - assert_structural_equal(index_map.map_shape([T.int64(14)]), [T.int64(4), T.int64(4)]) + assert_structural_equal(index_map.map_shape([14]), [4, 4]) + assert_structural_equal( + index_map.map_shape([const(16, "int64")]), [const(4, "int64"), const(4, "int64")] + ) + assert_structural_equal( + index_map.map_shape([const(14, "int64")]), [const(4, "int64"), const(4, "int64")] + ) def test_inverse(): @@ -77,28 +82,28 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[16], - post_shape=[T.int32(4), T.int32(4)], + post_shape=[4, 4], padding=lambda i, j: tvm.runtime.convert(False), ), "right_padding": dict( forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[15], - post_shape=[T.int32(4), T.int32(4)], + post_shape=[4, 4], padding=lambda i, j: tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), ), "left_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[15], - post_shape=[T.int32(4), T.int32(4)], + post_shape=[4, 4], padding=lambda i, j: tvm.tir.And(i == 0, j < 1), ), "left_and_right_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[14], - post_shape=[T.int32(4), T.int32(4)], + post_shape=[4, 4], padding=lambda i, j: tvm.tir.Or( tvm.tir.And(i == 0, j < 1), tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), @@ -108,7 +113,7 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[dynamic_N], - post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, T.int32(4)], + post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, 4], padding=lambda i, j: tvm.tir.And( dynamic_N % (-4) != 0, tvm.tir.And(i == dynamic_N // 4, j >= dynamic_N % 4), @@ -122,10 +127,10 @@ def test_nonbijective_inverse_gives_error(): ], pre_shape=[14, 31], post_shape=[ - T.int32(4), # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 - T.int32(5), # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 - T.int32(4), # Range of iter%4 - T.int32(8), # Range of iter%8 + 4, # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 + 5, # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 + 4, # Range of iter%4 + 8, # Range of iter%8 ], padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or( tvm.tir.Or( @@ -142,35 +147,35 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 32, (i // 4) % 8, i % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k], pre_shape=[116], - post_shape=[T.int32(4), T.int32(8), T.int32(4)], + post_shape=[4, 8, 4], padding=lambda i, j, k: tvm.tir.And(i == 3, 4 * j + k >= 20), ), "multiple_right_padding_transpose": dict( forward=lambda i: [(i // 4) % 8, i // 32, i % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k], pre_shape=[116], - post_shape=[T.int32(8), T.int32(4), T.int32(4)], + post_shape=[8, 4, 4], padding=lambda j, i, k: tvm.tir.And(i == 3, 4 * j + k >= 20), ), "multiple_left_padding": dict( forward=lambda i: [(i + 5) // 32, ((i + 5) // 4) % 8, (i + 5) % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k - 5], pre_shape=[123], - post_shape=[T.int32(4), T.int32(8), T.int32(4)], + post_shape=[4, 8, 4], padding=lambda i, j, k: tvm.tir.And(i == 0, j * 4 + k < 5), ), "multiple_left_padding_with_transpose": dict( forward=lambda i: [((i + 5) // 4) % 8, (i + 5) // 32, (i + 5) % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k - 5], pre_shape=[123], - post_shape=[T.int32(8), T.int32(4), T.int32(4)], + post_shape=[8, 4, 4], padding=lambda j, i, k: tvm.tir.And(i == 0, j * 4 + k < 5), ), "outer_loop_extent_one": dict( forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [i * 4 + j], pre_shape=[3], - post_shape=[T.int32(1), T.int32(4)], + post_shape=[1, 4], padding=lambda i, j: tvm.runtime.convert(3) == j, ), } diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index 29efd95280be..eeedae1f127c 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -32,7 +32,7 @@ def test_te_const(): assert isinstance(x, tvm.tir.IntImm) -def test_tir_const_dtype_inference(): +def test_scalar_dtype_inference(): for data in [ True, bool(1), @@ -49,11 +49,28 @@ def test_tir_const_dtype_inference(): np.float64(1), ]: assert tvm.tir.const(data).dtype == str(np.array(data).dtype) - - assert tvm.tir.const(True).dtype == "bool" assert tvm.tir.const(1).dtype == "int32" assert tvm.tir.const(1.0).dtype == "float32" + for data in [ + True, + bool(1), + np.uint8(1), + np.uint16(1), + np.uint32(1), + np.uint64(1), + np.int8(1), + np.int16(1), + np.int32(1), + np.int64(1), + np.float16(1), + np.float32(1), + np.float64(1), + ]: + assert tvm.runtime.convert(data).dtype == str(np.array(data).dtype) + assert tvm.runtime.convert(1).dtype == "int32" + assert tvm.runtime.convert(1.0).dtype == "float32" + def test_make(): x = tvm.tir.const(1, "int32") @@ -116,7 +133,7 @@ def test_attr(): assert stmt.node == y a = tvm.runtime.convert(1) - assert a == 1 + assert a.value == 1 try: a.no_field assert False @@ -333,7 +350,7 @@ def test_prim_func(): assert len(func.buffer_map) == 1 f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True}) - assert f2.attrs["calling_conv"] == 1 + assert f2.attrs["calling_conv"].value == 1 assert not func.attrs diff --git a/tests/python/tir-schedule/test_tir_schedule_sampling.py b/tests/python/tir-schedule/test_tir_schedule_sampling.py index 8ae576e9b922..c2f3f89e6e12 100644 --- a/tests/python/tir-schedule/test_tir_schedule_sampling.py +++ b/tests/python/tir-schedule/test_tir_schedule_sampling.py @@ -146,7 +146,7 @@ def test_sample_categorical_serialize(): decisions.append(rv) new_sch = verify_trace_roundtrip(sch, mod=elementwise) for i, new_inst in enumerate(new_sch.trace.insts): - assert decisions[i] == candidates[new_sch.trace.decisions[new_inst]] + assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] def test_sample_perfect_tile_power_of_two(): diff --git a/tests/python/tir-schedule/test_tir_schedule_state.py b/tests/python/tir-schedule/test_tir_schedule_state.py index c023b9dbc59d..74880e5a42d9 100644 --- a/tests/python/tir-schedule/test_tir_schedule_state.py +++ b/tests/python/tir-schedule/test_tir_schedule_state.py @@ -155,10 +155,10 @@ def test_replace_direct_write0(): old_hash = s.mod["main"].__hash__() sref = s.get_sref(s.mod["main"].body.block.body[1]) s.replace(sref, target) - # Check the replaced part is equal to the target - tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target) # There is no other reference so the AST node can be written directly assert old_hash == s.mod["main"].__hash__() + # Check the replaced part is equal to the target + tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target) # The target reuse the stmt of the sref, so the sref won't be None assert sref.stmt is not None diff --git a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py index cb7151f875e3..d5d5e0634ef6 100644 --- a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py +++ b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py @@ -1029,45 +1029,38 @@ class TestTileAwareCompaction(BaseCompactTest): # it is not an opaque block case intentionally is_lower_order_free = False - @property - def before(self): - @T.prim_func - def main( - A: T.Buffer((128, 128), "float32"), - B: T.Buffer((128, 128), "float32"), - C: T.Buffer((128, 128), "float32"), - ): - for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): - for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): - A_local = T.decl_buffer((26, 128), scope="local") - B_local = T.decl_buffer((128, 26), scope="local") - C_local = T.decl_buffer((26, 26), scope="local") - for ax0, ax1 in T.grid(26, 128): - if i_0 * 26 + ax0 < 128: - A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1] - for ax0, ax1 in T.grid(128, 26): - if j_0 * 26 + ax1 < 128: - B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1] - for i_1, j_1, k in T.grid(26, 26, 128): - if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128: - if k == 0: - C_local[i_1, j_1] = T.float32(0) - C_local[i_1, j_1] = ( - C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] - ) - for ax0, ax1 in T.grid(26, 26): - if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128: - C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] - - # Get partitioned workload to compact - mod = tvm.IRModule.from_expr(main) - with tvm.transform.PassContext( - config={"tir.LoopPartition": {"partition_const_loop": True}} - ): - mod = tvm.tir.transform.LowerOpaqueBlock()(mod) - mod = tvm.tir.transform.LoopPartition()(mod) - - return mod["main"] + @T.prim_func + def before( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): + for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): + A_local = T.decl_buffer((26, 128), scope="local") + B_local = T.decl_buffer((128, 26), scope="local") + C_local = T.decl_buffer((26, 26), scope="local") + for ax0, ax1 in T.grid(26, 128): + if i_0 * 26 + ax0 < 128: + A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1] + for ax0, ax1 in T.grid(128, 26): + if j_0 * 26 + ax1 < 128: + B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1] + for i_1, j_1, k in T.grid(26, 26, 128): + if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128: + if k == 0: + C_local[i_1, j_1] = T.float32(0) + C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] + for ax0, ax1 in T.grid(26, 26): + if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128: + C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] + + # Get partitioned workload to compact + before_mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): + before_mod = tvm.tir.transform.LowerOpaqueBlock()(before_mod) + before_mod = tvm.tir.transform.LoopPartition()(before_mod) + before = before_mod["main"] @T.prim_func def expected( diff --git a/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py b/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py index 3078572bb508..9f61b5a3920a 100644 --- a/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py @@ -14,12 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import pytest import tvm import tvm.testing -from tvm import te, tir - -import pytest +from tvm import te import numpy as np @@ -186,7 +184,7 @@ def collect_branch_stmt(x): if isinstance(x, tvm.tir.IfThenElse): branch_collector.append(x) - n = tir.const(21) + n = 21 A = te.placeholder((n,), name="A") B = te.placeholder((n,), name="B") diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index 0b43db56f300..23a51a0817df 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -394,144 +394,5 @@ def func_without_arg( tvm.ir.assert_structural_equal(Expected, After) -def test_int_parameter(): - """Boolean may be passed to functions accepting int - - A PackedFunc produced by compiling an IRModule should support the - same type conversions as the C++ implementation. When a function - accepts an integer argument, the caller may call it with a boolean - value. - - This also provides backwards compatibility for functions that were - defined as accepting an integer, but are called with a boolean - argument. Prior to PackedFunc interface supporting boolean - arguments directly, the argument would be converted from boolean - to integer to be stored in a TVMValue. After adding support for - boolean arguments, this usage should not cause an error. - - """ - - @I.ir_module - class Before: - @T.prim_func - def main(arg: T.int32) -> T.int32: - T.func_attr({"target": T.target("llvm", host="llvm")}) - if arg > 0: - return 10 - else: - return 20 - - @I.ir_module - class Expected: - @T.prim_func - def main( - args: T.handle, - arg_type_ids: T.handle("int32"), - num_args: T.int32, - out_ret_value: T.handle("void"), - out_ret_tcode: T.handle("int32"), - resource_handle: T.handle, - ) -> T.int32: - T.func_attr( - { - "calling_conv": 1, - "target": T.target("llvm"), - } - ) - assert num_args == 1, "main: num_args should be 1" - assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" - assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" - arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) - arg_code: T.int32 = arg_type_ids_1[0] - assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be int" - arg: T.int32 = T.if_then_else( - arg_code == 0, - T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")), - T.Cast("int32", T.tvm_struct_get(args, 0, 12, "bool")), - ) - with T.attr(0, "compute_scope", "main_compute_"): - out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) - out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) - if arg > 0: - out_ret_value_1[0] = T.Cast("int64", 10) - out_ret_tcode_1[0] = 0 - return 0 - else: - out_ret_value_1[0] = T.Cast("int64", 20) - out_ret_tcode_1[0] = 0 - return 0 - return 0 - - After = tvm.tir.transform.MakePackedAPI()(Before) - - tvm.ir.assert_structural_equal(Expected, After) - - -def test_bool_parameter(): - """An integer may be passed to a function acccepting Boolean - - A PackedFunc produced by compiling an IRModule should support the - same type conversions as the C++ implementation. When a function - accepts a boolean argument, the caller may call it with an integer - value. - - """ - - @I.ir_module - class Before: - @T.prim_func - def main(arg: T.bool) -> T.int32: - T.func_attr({"target": T.target("llvm", host="llvm")}) - if arg: - return 10 - else: - return 20 - - @I.ir_module - class Expected: - @T.prim_func - def main( - args: T.handle, - arg_type_ids: T.handle("int32"), - num_args: T.int32, - out_ret_value: T.handle("void"), - out_ret_tcode: T.handle("int32"), - resource_handle: T.handle, - ) -> T.int32: - T.func_attr( - { - "calling_conv": 1, - "target": T.target("llvm"), - } - ) - assert num_args == 1, "main: num_args should be 1" - assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" - assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" - arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) - arg_code: T.int32 = arg_type_ids_1[0] - assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be boolean" - arg: T.bool = T.if_then_else( - arg_code == 15, - T.tvm_struct_get(args, 0, 12, "bool"), - T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")), - ) - with T.attr(0, "compute_scope", "main_compute_"): - out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) - out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) - if arg: - out_ret_value_1[0] = T.Cast("int64", 10) - out_ret_tcode_1[0] = 0 - return 0 - else: - out_ret_value_1[0] = T.Cast("int64", 20) - out_ret_tcode_1[0] = 0 - return 0 - return 0 - - After = tvm.tir.transform.MakePackedAPI()(Before) - - tvm.ir.assert_structural_equal(Expected, After) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index 68149e7d64bb..4b71eb825414 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -937,8 +937,8 @@ def test_vulkan_smem_reuse(): "kind": "vulkan", "max_num_threads": 256, "max_threads_per_block": 256, - "supports_float32": True, - "supports_int32": True, + "supports_float32": T.bool(True), + "supports_int32": T.bool(True), "tag": "", "thread_warp_size": 1, } diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py index d8212d38854c..279785fdca51 100644 --- a/tests/python/tvmscript/test_tvmscript_error_report.py +++ b/tests/python/tvmscript/test_tvmscript_error_report.py @@ -332,35 +332,26 @@ def convert_slice_to_bufferload() -> None: check_error(convert_slice_to_bufferload, 6) -def test_tvm_exception_catch_from_special_stmt(): +def test_tvm_exception_catch(): def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error T.evaluate(1.0) - check_error(special_stmt_except, 2) - - -def test_tvm_exception_catch_from_scope_handler(): def scope_handler_except() -> None: for i in T.serial("1", "1"): # error T.evaluate(1) - check_error(scope_handler_except, 2) - - -def test_tvm_exception_catch_from_bare_intrin(): def intrin_except_unassign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") T.evaluate(A) # error - check_error(intrin_except_unassign, 3) - - -def test_tvm_exception_catch_from_assigned_intrin(): def intrin_except_assign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") A[0, 0] = A[A] # error + check_error(special_stmt_except, 2) + check_error(scope_handler_except, 2) + check_error(intrin_except_unassign, 3) check_error(intrin_except_assign, 3) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index b7ba57fa9387..8364e65a4178 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -230,7 +230,7 @@ def test_buffer_store(): obj, """ A = T.Buffer((128, 128), "float16") -A[128, 128] = A[128, 128] + T.float16(1.0) +A[128, 128] = A[128, 128] + T.float16(1) """, ) @@ -259,7 +259,7 @@ def test_let_stmt(): _assert_print( obj, """ -with T.LetStmt(T.float32(10.0)) as v: +with T.LetStmt(T.float32(10)) as v: T.evaluate(0) """, ) @@ -672,7 +672,7 @@ def test_call(): _assert_print( obj, """ -T.atan(T.float32(1.0)) +T.atan(T.float32(1)) """, ) @@ -682,7 +682,7 @@ def test_comm_reducer(): _assert_print( obj, """ -T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]) +T.comm_reducer(lambda x, y: x + y, [T.float32(0)]) """, ) @@ -712,7 +712,7 @@ def test_float_imm(): _assert_print( obj, """ -T.float16(1.0) +T.float16(1) """, ) @@ -942,7 +942,7 @@ def func(): @T.prim_func def func(): - T.evaluate(T.{dtype}(0.0)) + T.evaluate(T.{dtype}(0)) """ func = get_func(dtype) _assert_print(func, expected_output) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index b44ff5ad7241..f81a80de6d61 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2689,14 +2689,14 @@ def test_match_buffer_region(): outer_block = root.body.body.body.block assert len(outer_block.match_buffers) == 1 buffer_C = outer_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_C.shape, [T.int32(16), T.int32(1), T.int32(4)]) + tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) assert isinstance(outer_block.body, tir.stmt.For) assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) inner_block = outer_block.body.body.block assert len(inner_block.match_buffers) == 1 buffer_D = inner_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_D.shape, [T.int32(4), T.int32(1), T.int32(4)]) + tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) def block_elements(): @@ -3981,32 +3981,6 @@ def func() -> T.int32: return func -def func_attr_with_list(): - @T.prim_func - def func( - A: T.Buffer((128, 128), "float32"), - B: T.Buffer((128, 128), "float32"), - D: T.Buffer((128, 128), "float32"), - ) -> None: - T.func_attr( - {"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [T.int32(1)]} - ) - C = T.alloc_buffer([128, 128], dtype="float32") - for i0, i1, i2 in T.grid(128, 128, 128): - with T.block("C"): - x, y, k = T.axis.remap("SSR", [i0, i1, i2]) - with T.init(): - C[x, y] = T.float32(0) - C[x, y] = C[x, y] + A[x, k] * B[y, k] - for i0, i1 in T.grid(128, 128): - with T.block("D"): - T.block_attr({"layout_free_placeholders": [C]}) - x, y = T.axis.remap("SS", [i0, i1]) - D[x, y] = C[x, y] + T.float32(1) - - return func - - def op_of_literal(): op_list = [ (T.exp, 0), @@ -4224,7 +4198,6 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return_zero, return_zero_private, return_zero_private_with_attr, - func_attr_with_list, *op_of_literal(), *relax_match_cast_struct_info_proxy(), relax_symbolic_size_var, diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index ae83a9d66392..9bc9800c1cb8 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -19,7 +19,6 @@ import tvm from tvm import te from tvm.topi import utils -from tvm.script import tir as T from .environment import get_env @@ -1047,19 +1046,19 @@ def _flatten_loop(src_coeff, dst_coeff, extents): assert len(dst_coeff) > 1 assert len(extents) != 0 tvm.ir.assert_structural_equal( - analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), T.int32(0) + analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0 ) tvm.ir.assert_structural_equal( - analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), T.int32(0) + analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0 ) - tvm.ir.assert_structural_equal(src_coeff[-2], T.int32(1)) - tvm.ir.assert_structural_equal(dst_coeff[-2], T.int32(1)) + tvm.ir.assert_structural_equal(src_coeff[-2], 1) + tvm.ir.assert_structural_equal(dst_coeff[-2], 1) if env.BATCH > 1: assert len(src_coeff) > 2 assert len(dst_coeff) > 2 assert len(extents) > 1 - tvm.ir.assert_structural_equal(src_coeff[-3], T.int32(env.BLOCK_OUT)) - tvm.ir.assert_structural_equal(dst_coeff[-3], T.int32(env.BLOCK_OUT)) + tvm.ir.assert_structural_equal(src_coeff[-3], env.BLOCK_OUT) + tvm.ir.assert_structural_equal(dst_coeff[-3], env.BLOCK_OUT) # Apply tensorization of the loop coefficients src_offset = src_coeff[-1] From 1fcb62023f0a5f878abd5b43ec9e547933fb5fab Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Thu, 8 Aug 2024 08:39:43 -0400 Subject: [PATCH 052/202] [WebGPU] Fix unexpected device lost error when intentional dispose (#17250) --- web/src/runtime.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/src/runtime.ts b/web/src/runtime.ts index d71c98e7d1bc..e446c4dc4dfb 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -1122,7 +1122,7 @@ export class Instance implements Disposable { // ctx release goes back into lib. this.ctx.dispose(); this.lib.dispose(); - this.deviceLostIsError = true; + // Cannot set deviceLostIsError back to true here because GPUDevice.destroy() is asynchronous. } /** @@ -2122,6 +2122,7 @@ export class Instance implements Disposable { this.dispose(); } }); + this.deviceLostIsError = true; const webGPUContext = new WebGPUContext( this.memory, device From 77391714ab714afcc849fde1378a5a0c62d99c2e Mon Sep 17 00:00:00 2001 From: sdalvi-quic <135273488+sdalvi-quic@users.noreply.github.com> Date: Fri, 9 Aug 2024 00:27:35 -0500 Subject: [PATCH 053/202] Replacing unary ops with LookUpTable and Take op to improve performance (#17214) * Created Look Up Table for unary ops such that the values are computed during compile time and take op is used to access the values at runtime * Black formatting for hexagon_unary_ops.py * minor edit * Accessed variables with op attributes and op name in the prim fucn definition. Added check if the call node is of call tir type --- .../tvm/contrib/hexagon/generate_take_op.py | 98 +++++ .../tvm/contrib/hexagon/hexagon_unary_ops.py | 97 +++++ .../python/contrib/test_hexagon/test_take.py | 393 ++++++++++++++++++ 3 files changed, 588 insertions(+) create mode 100644 python/tvm/contrib/hexagon/generate_take_op.py create mode 100644 python/tvm/contrib/hexagon/hexagon_unary_ops.py create mode 100644 tests/python/contrib/test_hexagon/test_take.py diff --git a/python/tvm/contrib/hexagon/generate_take_op.py b/python/tvm/contrib/hexagon/generate_take_op.py new file mode 100644 index 000000000000..b70eb451a1a5 --- /dev/null +++ b/python/tvm/contrib/hexagon/generate_take_op.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name, unnecessary-comprehension, unused-argument + +import tvm +import tvm.testing +from tvm import relax +from tvm.contrib.hexagon import hexagon_unary_ops + + +def op_replace(call_node, func) -> bool: + if not isinstance(call_node, relax.Call): + return False + call_tir_op = tvm.ir.Op.get("relax.call_tir") + if call_node.op != call_tir_op: + return False + ops = [ + "qnn.tanh", + "qnn.sqrt", + "qnn.rsqrt", + "qnn.exp", + "qnn.erf", + "qnn.sigmoid", + "qnn.hardswish", + "qnn.log", + "qnn.abs", + ] + if func.attrs["op_attrs"]["op_name"] in ops: + return True + return False + + +@relax.expr_functor.mutator +class Tanh2TakeReplace(tvm.relax.PyExprMutator): + def __init__(self, mod: tvm.IRModule) -> None: + super().__init__(mod) + self.mod_ = mod + + def transform(self) -> tvm.IRModule: + # Iterate over all the nodes to check for the node replaceable + for global_var, func in self.mod_.functions.items(): + # Skip non-relax functions + if not isinstance(func, relax.Function): + continue + updated_func = self.visit_expr(func) + self.builder_.normalize(updated_func) + self.builder_.update_func(global_var, updated_func) + # At the end of the transformation we return the updated IRModule from the BlockBuilder. + return self.builder_.get() + + def visit_call_(self, call_node: relax.Call) -> relax.Call: + call_tir_op = tvm.ir.Op.get("relax.call_tir") + if call_node.op != call_tir_op: + return call_node + + var = call_node.args[0] + func = self.mod_[var] + + if call_node.args[1][0].struct_info.dtype == "uint8": + if op_replace(call_node, func): + inp, inp_scale, inp_zp, out_scale, out_zp = [x for x in call_node.args[1]] + # LUT node creation + LUT = hexagon_unary_ops.LUT_generation( + inp_scale, inp_zp, out_scale, out_zp, call_node.args[0].name_hint + ) + # Take operation node creation + take_func = hexagon_unary_ops.generate_take_primfunc(inp, call_node.struct_info) + take_func = take_func.without_attr("global_symbol") + take_func_gv = self.builder_.add_func(take_func, "take") + take_node = relax.call_tir( + take_func_gv, + relax.expr.Tuple( + [call_node.args[1][0], relax.expr.Constant(tvm.nd.array(LUT))] + ), + call_node.struct_info, + ) + return take_node + return call_node + + +@tvm.ir.transform.module_pass(opt_level=2, name="replace_tanh_take") +class PassReplaceWithTakeOpPrimFuncs: + def transform_module(self, mod, ctx): + return Tanh2TakeReplace(mod).transform() diff --git a/python/tvm/contrib/hexagon/hexagon_unary_ops.py b/python/tvm/contrib/hexagon/hexagon_unary_ops.py new file mode 100644 index 000000000000..1bb4d4ba4f7c --- /dev/null +++ b/python/tvm/contrib/hexagon/hexagon_unary_ops.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +import logging +import numpy as np +from scipy import special +from tvm import te + +logger = logging.getLogger(__name__) + +###################################################################### +#################### PRIMFUNC FOR LUT and Take Op #################### +###################################################################### + + +def saturate(x: te.Tensor, dtype: str): + """Saturate value for the specified data type""" + return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype))) + + +def hardswish_func(x): + x_2 = np.add(x, 3.0) + x_2 = np.clip(x_2, 0.0, 6.0) + return x * x_2 / 6.0 + + +def LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> None: + LUT = [] + for i in range(256): + i = np.int32(i) + # converting the constants to the numpy value + if inp_zp.data.shape == (): + i_zp = inp_zp.data.numpy()[()] + if inp_scale.data.shape == (): + i_scale = inp_scale.data.numpy()[()] + if out_zp.data.shape == (): + o_zp = out_zp.data.numpy()[()] + if out_scale.data.shape == (): + o_scale = out_scale.data.numpy()[()] + # Dequantization followed by computing the op value + dequant = (i - i_zp) * i_scale + if "tanh" in op_name: + op_val = np.tanh(dequant) + elif "rsqrt" in op_name: + op_val = 1 / np.sqrt(dequant) + elif "sqrt" in op_name: + op_val = np.sqrt(dequant) + elif "exp" in op_name: + op_val = np.exp(dequant) + elif "erf" in op_name: + op_val = special.erf(dequant) + elif "sigmoid" in op_name: + op_val = 1 / (1 + np.exp(np.negative(dequant))) + elif "hardswish" in op_name: + op_val = hardswish_func(dequant) + elif "log" in op_name: + op_val = np.log(dequant) + elif "abs" in op_name: + op_val = np.abs(dequant) + else: + logger.error("Error op is other than unary op") + + # Quantizing the value generated and appending in the Look Up Table + quant = np.round((op_val) / o_scale) + o_zp + val = np.maximum(0, np.minimum(quant, 255)).astype(np.uint8) + LUT.append(val) + return LUT + + +def generate_take_primfunc(inp, struct_info): + # Generating the take op + N, H, W, C = inp.struct_info.shape + data = te.placeholder((N, H, W, C), dtype=struct_info.dtype, name="data") + LUT_func = te.placeholder((256,), dtype="uint8", name="LUT") + take = te.compute( + struct_info.shape, + lambda *indices: saturate( + (LUT_func[data[indices].astype("uint8")]), struct_info.dtype + ).astype(struct_info.dtype), + name="take_op", + ) + mod = te.create_prim_func([data, LUT_func, take]) + return mod diff --git a/tests/python/contrib/test_hexagon/test_take.py b/tests/python/contrib/test_hexagon/test_take.py new file mode 100644 index 000000000000..80c2b053395f --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_take.py @@ -0,0 +1,393 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name, unused-argument, not-callable +import numpy as np +from scipy import special + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import tir as T, relax as R +from tvm.contrib.hexagon import generate_take_op +from tvm.contrib.hexagon import hexagon_unary_ops + +from .infrastructure import quantize_np + + +# Testing the structural and value correctness on replacing unary op with take op. + + +@tvm.script.ir_module +class Module_tanh: + @R.function + def main( + input_tanh: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_tanh.tanh, + ( + input_tanh, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.002631544131858676, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def tanh( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.tanh"}}) + + +@tvm.script.ir_module +class Module_sqrt: + @R.function + def main( + input_sqrt: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_sqrt.sqrt, + ( + input_sqrt, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.003535157327728918, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def sqrt( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.sqrt"}}) + + +@tvm.script.ir_module +class Module_rsqrt: + @R.function + def main( + input_rsqrt: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_rsqrt.rsqrt, + ( + input_rsqrt, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.008154160766635542, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def rsqrt( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.rsqrt"}}) + + +@tvm.script.ir_module +class Module_exp: + @R.function + def main( + input_exp: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_exp.exp, + ( + input_exp, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.008838622987079832, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def exp( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.exp"}}) + + +@tvm.script.ir_module +class Module_erf: + @R.function + def main( + input_erf: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_erf.erf, + ( + input_erf, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.002939393251118067, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def erf( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.erf"}}) + + +@tvm.script.ir_module +class Module_sigmoid: + @R.function + def main( + input_sigmoid: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_sigmoid.sigmoid, + ( + input_sigmoid, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.002631544131858676, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def sigmoid( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.sigmoid"}}) + + +@tvm.script.ir_module +class Module_hardswish: + @R.function + def main( + input_hardswish: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_hardswish.hardswish, + ( + input_hardswish, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.0020250332087720325, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def hardswish( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.hardswish"}}) + + +@tvm.script.ir_module +class Module_log: + @R.function + def main( + input_log: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_log.log, + ( + input_log, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.0057414634248614226, "float32"), + R.const(255, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def log( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.log"}}) + + +@tvm.script.ir_module +class Module_abs: + @R.function + def main( + input_abs: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_abs.abs, + ( + input_abs, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.0031868210196078434, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def abs( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.abs"}}) + + +# data = np.random.random([1, 2, 2, 2]).astype("float32") : Need to hadcode the data +# so that we can get the quantization parameters and use them as input to the main func +data = [ + [ + [[0.3034368, 0.60848576], [0.29697746, 0.67340654]], + [[0.656068, 0.23129226], [0.42117321, 0.81263936]], + ] +] +dtype = "uint8" + +# Quantizing input : scale is returned as float64 and zp is returned as int32 +inp_quant, inp_scale, inp_zero_point = quantize_np(data, dtype) +inp_quant = tvm.nd.array(inp_quant.astype(np.uint8)) + + +# Test the implementations value output with numpy data. First the IR is runn through pass +# to replace unary op with take op. Followed by value testing. +def test_value(): + ops = ["tanh", "sqrt", "rsqrt", "exp", "erf", "sigmoid", "hardswish", "log", "abs"] + + atol_val = 2 + for op_name in ops: + if op_name == "tanh": + op_val = np.tanh(data) + before = Module_tanh + elif op_name == "sqrt": + op_val = np.sqrt(data) + before = Module_sqrt + elif op_name == "rsqrt": + op_val = 1 / np.sqrt(data) + before = Module_rsqrt + elif op_name == "exp": + op_val = np.exp(data) + before = Module_exp + elif op_name == "erf": + op_val = special.erf(data) + before = Module_erf + elif op_name == "sigmoid": + op_val = 1 / (1 + np.exp(np.negative(data))) + atol_val = 15 + before = Module_sigmoid + elif op_name == "hardswish": + op_val = hexagon_unary_ops.hardswish_func(data) + before = Module_hardswish + elif op_name == "log": + op_val = np.log(data) + before = Module_log + elif op_name == "abs": + op_val = np.abs(data) + before = Module_abs + + # Quantizing output : scale is returned as float64 and zp is returned as int32 + out_quant, _, _ = quantize_np(op_val, dtype) + + after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(before) + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(after, target, exec_mode="compiled") + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"](inp_quant) + + tvm.testing.assert_allclose(res.numpy(), out_quant, atol=atol_val) + print("Passed Value : ", op_name) + + +# Testing the structural implementation, if the unary op is replaced with take op. +def test_structural(): + Modules = [ + Module_tanh, + Module_sqrt, + Module_rsqrt, + Module_exp, + Module_erf, + Module_sigmoid, + Module_hardswish, + Module_log, + Module_abs, + ] + for mod in Modules: + after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(mod) + assert not tvm.ir.structural_equal(after["main"], mod["main"]) + print("Passed Structural") From b40a02c265ad029a6dec2eef808b48945e39c31b Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 9 Aug 2024 21:44:14 +0800 Subject: [PATCH 054/202] [Relax] Add KVCache Interface for Relax NNModule (#17261) Introduce kv cache interface for Relax NNModule to support paged attention. Note that the implementation is migrated from MLC-llm Co-authored-by: Bohan Hou Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin Co-authored-by: krishnaraj36 --- python/tvm/relax/frontend/nn/llm/__init__.py | 22 + python/tvm/relax/frontend/nn/llm/kv_cache.py | 1636 +++++++++++++++ .../frontend/nn/llm/position_embedding.py | 287 +++ python/tvm/relax/frontend/nn/llm/tree_attn.py | 411 ++++ ...me_builtin_paged_attention_kv_cache_tir.py | 1765 +---------------- 5 files changed, 2371 insertions(+), 1750 deletions(-) create mode 100644 python/tvm/relax/frontend/nn/llm/__init__.py create mode 100644 python/tvm/relax/frontend/nn/llm/kv_cache.py create mode 100644 python/tvm/relax/frontend/nn/llm/position_embedding.py create mode 100644 python/tvm/relax/frontend/nn/llm/tree_attn.py diff --git a/python/tvm/relax/frontend/nn/llm/__init__.py b/python/tvm/relax/frontend/nn/llm/__init__.py new file mode 100644 index 000000000000..03c86880bbb1 --- /dev/null +++ b/python/tvm/relax/frontend/nn/llm/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""LLM support for PyTorch-like API to build IRModules.""" + +from . import kv_cache, position_embedding +from .position_embedding import llama_rope +from .tree_attn import tree_attn +from .kv_cache import PagedKVCache diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py new file mode 100644 index 000000000000..25a3a1a00ddc --- /dev/null +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -0,0 +1,1636 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Attention KV cache modeling.""" + +# pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name +import enum +import math +from typing import Tuple + +from tvm import relax as rx +from tvm import tir +from tvm.relax.frontend.nn import Object, Tensor +from tvm.runtime import DataType +from tvm.script import tir as T +from tvm.target import Target + +from .position_embedding import llama_rope_with_position_map, rope_freq +from .tree_attn import tree_attn + + +def get_max_num_threads_per_block(target: Target) -> int: + """ + max(max_num_threads, max_threads_per_block); if latter does not exist, return max_num_threads. + We add this method since some targets have both fields and `max_threads_per_block` is larger. + """ + max_num_threads = target.max_num_threads + max_threads_per_block = target.attrs.get("max_threads_per_block", None) + if max_threads_per_block is None: + return max_num_threads + return max(max_num_threads, max_threads_per_block) + + +def check_thread_limits(target: Target, bdx: int, bdy: int, bdz: int, gdz: int): + """ + Check whether max num threads exceeded given a target. + + Parameters + ---------- + bdx: threadIdx.x + bdy: threadIdx.y + bdz: threadIdx.z + gdz: blockIdx.z + """ + max_num_threads_per_block = get_max_num_threads_per_block(target) + + assert ( + bdx * bdy * bdz <= max_num_threads_per_block + ), f"{target.kind} max num threads exceeded: {bdx}*{bdy}*{bdz}>{max_num_threads_per_block}" + + if str(target.kind) == "webgpu": + # https://gpuweb.github.io/gpuweb/#dom-supported-limits-maxcomputeworkgroupsizez + assert bdz <= 64, f"webgpu's threadIdx.z cannot exceed 64, but got bdz={bdz}" + assert gdz == 1, f"webgpu's blockIdx.z should be 1, but got gdz={gdz}" + + +class RopeMode(enum.IntEnum): + """The RoPE mode of the Paged KV cache. + If it is none, the KV cache will not apply RoPE to q and k. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. + """ + + NONE = 0 + NORMAL = 1 + INLINE = 2 + + +class PagedKVCache(Object): # pylint: disable=too-few-public-methods + """The Paged KV Cache used in LLM batching for efficient attention computation.""" + + def attention_with_fused_qkv( + self, + layer_id: int, + qkv: Tensor, + num_qo_heads: int, + attn_score_scaling_factor: float = 1.0, + ) -> Tensor: + """Compute attention with the given fused q/k/v data and in-cache k/v data + on the specified layer. Rotary position embeddings are applied to k/v + within this function. + + - For prefill, the input qkv and output tensor have shape + (1, total_seq_len) for the first two dimensions. + - For decode, the input qkv and output tensor have shape + (batch_size, 1) for the first two dimensions. + - The input qkv have `2 * num_qo_heads + num_kv_heads` at the third dim. + - The output tensor have `num_qo_heads` at the third dim. + - The input qkv and output tensor have `head_dim` at the last dim. + """ + # pylint: disable=protected-access + b, s, _, d = qkv._expr.struct_info.shape + qkv = qkv.reshape(b * s, qkv.shape[2], d) + return Tensor( + _expr=rx.BlockBuilder.current().emit( + rx.call_dps_packed( + "vm.builtin.attention_kv_cache_attention_with_fused_qkv", + [ + self._expr, + rx.PrimValue(layer_id), # type: ignore[arg-type] + rx.PrimValue(attn_score_scaling_factor), + qkv._expr, + ], + out_sinfo=rx.TensorStructInfo((b * s, num_qo_heads, d), qkv.dtype), + ) + ) + ).reshape(b, s, num_qo_heads, d) + + def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor: + """Get the in-sequence positions of each slot in the query, + which are needed for applying positional embeddings in some models. + + Parameters + ---------- + total_length : tir.PrimExpr + The summed-up total sequence length of queries in + the batch being forwarded. + + Returns + ------- + q_positions : Tensor + The in-sequence query positions, in shape `(total_length,)` + """ + return Tensor( + _expr=rx.BlockBuilder.current().emit( + rx.call_pure_packed( + "vm.builtin.attention_kv_cache_get_query_positions", + self._expr, + sinfo_args=rx.TensorStructInfo((total_length,), "int32"), + ) + ) + ) + + # pylint: enable=protected-access + + +class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods + """Paged KV cache using FlashInfer (CUDA) kernels.""" + + def __init__( # pylint: disable=too-many-locals + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + layer_partition: rx.ShapeExpr, + num_hidden_layers: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rope_mode: RopeMode, + rope_scale: int, + rope_theta: int, + rotary_dim: int, + dtype: str, + target: Target, + name: str = "paged_kv_cache", + ) -> None: + """Create a paged KV cache object with FlashInfer kernels. + + Parameters + ---------- + max_batch_size : tir.Var + The maximum allowed batch size of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + max_total_seq_len : tir.Var + The maximum allowed total sequence length of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + prefill_chunk_size : tir.Var + The maximum total sequence length in a prefill. + It is a symbolic variable whose concrete value is specified + at runtime. + page_size : tir.Var + The size (a.k.a. number of tokens) of each page. + It is a symbolic variable whose concrete value is specified + at runtime. + support_sliding_window : tir.Var + 0 or 1, denoting whether the KV cache supports sliding window. + It is a symbolic variable whose concrete value is specified + at runtime. + rope_mode : RopeMode + The RoPE mode of the Paged KV cache. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. + rope_scale : int + The scale of rotary position embedding. + rope_theta : int + The base of rotary position embedding. + rope_scaling: Dict[str, Any] + The RoPE scaling information dict. + rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. + """ + if rope_mode == RopeMode.INLINE: + assert rotary_dim == head_dim, "FlashInfer RoPE does not support partial rotary dim." + + bb = rx.BlockBuilder.current() + args = [ + rx.ShapeExpr( + [ + max_batch_size, + max_total_seq_len, + prefill_chunk_size, + page_size, + support_sliding_window, + ] + ), + layer_partition, + rx.PrimValue(num_attention_heads), + rx.PrimValue(num_key_value_heads), + rx.PrimValue(head_dim), + rx.PrimValue(rope_mode), + rx.PrimValue(rope_scale), + rx.PrimValue(rope_theta), + rx.op.zeros((), dtype), + # pylint: disable=line-too-long + # fmt: off + bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), + rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache"), + rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), + rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward"), + rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward"), + rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward"), + rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"), + rx.extern("flashinfer.merge_state_in_place"), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), + bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"), + # fmt: on + # pylint: enable=line-too-long + ] + super().__init__( + _expr=rx.call_pure_packed( + "vm.builtin.paged_attention_kv_cache_create", + *args, + sinfo_args=rx.ObjectStructInfo(), + ), + _name=name, + ) + + +class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods + """Paged KV cache using TIR kernels.""" + + def __init__( # pylint: disable=too-many-locals + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + layer_partition: rx.ShapeExpr, + num_hidden_layers: int, + num_attention_heads: int, + num_key_value_heads: int, + rope_mode: RopeMode, + head_dim: int, + rope_scale: int, + rope_theta: int, + rotary_dim: int, + dtype: str, + target: Target, + name: str = "paged_kv_cache", + ) -> None: + """Create a paged KV cache object with TIR kernels. + + Parameters + ---------- + max_batch_size : tir.Var + The maximum allowed batch size of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + max_total_seq_len : tir.Var + The maximum allowed total sequence length of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + prefill_chunk_size : tir.Var + The maximum total sequence length in a prefill. + It is a symbolic variable whose concrete value is specified + at runtime. + page_size : tir.Var + The size (a.k.a. number of tokens) of each page. + It is a symbolic variable whose concrete value is specified + at runtime. + support_sliding_window : tir.Var + 0 or 1, denoting whether the KV cache supports sliding window. + It is a symbolic variable whose concrete value is specified + at runtime. + layer_partition : rx.ShapeExpr + The KV cache layer partition for pipeline stages. + It is an indptr array, denoting the starting layer of each pipeline stage. + rope_mode : RopeMode + The RoPE mode of the Paged KV cache. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. + rope_scale : int + The scale of rotary position embedding. + rope_theta : int + The base of rotary position embedding. + rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. + target : Target + The target to build the model to. + """ + + bb = rx.BlockBuilder.current() + args = [ + rx.ShapeExpr( + [ + max_batch_size, + max_total_seq_len, + prefill_chunk_size, + page_size, + support_sliding_window, + ] + ), + layer_partition, + rx.PrimValue(num_attention_heads), + rx.PrimValue(num_key_value_heads), + rx.PrimValue(head_dim), + rx.PrimValue(rope_mode), + rx.PrimValue(rope_scale), + rx.PrimValue(rope_theta), + rx.op.zeros((), dtype), + # pylint: disable=line-too-long + # fmt: off + bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, target), "tir_attention_prefill"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, target), "tir_attention_decode"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), + bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"), + bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), + bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"), + # fmt: on + # pylint: enable=line-too-long + ] + super().__init__( + _expr=rx.call_pure_packed( + "vm.builtin.paged_attention_kv_cache_create_reduced", + *args, + sinfo_args=rx.ObjectStructInfo(), + ), + _name=name, + ) + + +# mypy: disable-error-code="attr-defined,valid-type,no-redef" +# pylint: disable=too-many-locals + + +def _kv_cache_transpose_append(num_key_value_heads, head_dim, dtype): + """Return the TIR function that appends new k/v data to PagedKVCache.""" + + # pylint: disable=line-too-long + # fmt: off + @T.prim_func + def tir_kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_position_map: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") + num_pages = T.int64() + position_map_elem_offset = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype) + k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype) + v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset + ) + for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): + if position_map[global_pos] != T.int32(-1): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore + pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] + pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] + # fmt: on + # pylint: enable=line-too-long + + return tir_kv_cache_transpose_append + + +def _kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype): + """Return the TIR function that fetches the k/v data on given positions and layer.""" + + # pylint: disable=line-too-long + # fmt: off + @T.prim_func + def tir_kv_cache_debug_get_kv( + var_pages: T.handle, + var_position_map: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + layer_id: T.int64, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + seqlen = T.SizeVar("num_tokens_including_cache", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + position_map_elem_offset = T.int64() + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset + ) + k_data = T.match_buffer(var_k_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) + v_data = T.match_buffer(var_v_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) + for p, h, d in T.grid(seqlen, num_key_value_heads, head_dim): + with T.block("copy0"): + vp, vh, vd = T.axis.remap("SSS", [p, h, d]) + T.reads(position_map[vp], pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd]) + T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) + position: T.int32 = position_map[vp] # type: ignore[name-defined] + k_data[layer_id, vp, vh, vd] = pages[T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd] + v_data[layer_id, vp, vh, vd] = pages[T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd] + # fmt: on + # pylint: enable=line-too-long + + return tir_kv_cache_debug_get_kv + + +def _rope( + buffer: T.Buffer, + offset: tir.Var, + rotary_dim: int, + theta: tir.Var, + scale: tir.Var, + indices: Tuple[tir.Var, ...], + qkv_dtype="float16", +): + d = indices[-1] + cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, "float32") + cos = cos_freq * buffer[indices].astype("float32") + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -buffer[indices[:-1] + (d + rotary_dim // 2,)], + buffer[indices[:-1] + (d - rotary_dim // 2,)], + ).astype("float32") + return (cos + sin).astype(qkv_dtype) + + +def _var(dtype): + return T.alloc_buffer((1,), dtype, scope="local") + + +def _causal_mask(causal, row, col, kv_len, qo_len): + return T.if_then_else( + causal > 0, + col < kv_len - qo_len + row + 1, + col < kv_len, + ) + + +def _declare_length_info(var_length_info, batch_size, sliding_window, elem_offset): + return ( + T.match_buffer(var_length_info, (3, batch_size), "int32", elem_offset=elem_offset) + if sliding_window + else T.match_buffer(var_length_info, (batch_size,), "int32", elem_offset=elem_offset) + ) + + +def _get_kv_chunk_len(num_pages, page_size, seq_id, length_info, sliding_window): + if not sliding_window: + return (num_pages - 1) * page_size + length_info[seq_id] + # ((num_pages - 1) * page_size + last_page_len) - sliding_window_offset + sink_size + return ( + (num_pages - 1) * page_size + + length_info[0, seq_id] + - length_info[1, seq_id] + + length_info[2, seq_id] + ) + + +def _get_seq_offset(pos, seq_id, length_info, sliding_window): + if not sliding_window: + return pos + # pos if pos < sink_size else pos - sink_size + sliding_window_offset + return T.if_then_else( + pos < length_info[2, seq_id], + pos, + pos - length_info[2, seq_id] + length_info[1, seq_id], + ) + + +def _attention_prefill(h_kv, h_q, d, dtype, sliding_window: bool, target: Target): + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + bdx = 32 + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + check_thread_limits(target, bdx=bdx, bdy=num_warps, bdz=1, gdz=1) + + global_symbol = "batch_prefill_paged_kv" + if sliding_window: + global_symbol += "_sliding_window" + + # pylint: disable=line-too-long,too-many-branches + # fmt: off + @T.prim_func + def batch_prefill_paged_kv( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_len] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + T.func_attr({"global_symbol": global_symbol}) + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (total_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset) + output = T.match_buffer(var_output, (total_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". + length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset) + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * tile_x + q_indptr_val: T.int32 = q_indptr[b_idx] + + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), + 0 + ) + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + q[cur_L, cur_H_qo, j] + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + for lz, ly in T.grid(tile_z, tile_y): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype), + pages[page_no, 0, by, page_offset, j] + ) + else: + K_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, tile_y): + with T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + V_smem[i, j] = pages[page_no, 1, by, page_offset, j] + else: + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size + for j in T.serial(tile_z): + if _causal_mask(causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = (LH_start + row) // group_size + if _causal_mask(causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,too-many-branches + sch = tir.Schedule(batch_prefill_paged_kv) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def _attention_decode( + num_kv_heads, + num_qo_heads, + head_dim, + qkv_dtype, + sliding_window: bool, + target: Target, +): + qkv_dtype_bytes = 2 + H_qo = num_qo_heads + H_kv = num_kv_heads + D = head_dim + + THREAD_LIMIT = 512 + TILE_SIZE_PER_BDX = 2 + if target.kind.name == "opencl" and "android" in str(target.host): + THREAD_LIMIT = 256 if H_kv < 8 else 512 + TILE_SIZE_PER_BDX = 1 + max_num_threads_per_block = get_max_num_threads_per_block(target) + thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) + + GROUP_SIZE = H_qo // H_kv + VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4) + bdx = D // VEC_SIZE + bdy = GROUP_SIZE + while bdx * bdy > thread_limit and bdy > 1: + bdy //= 2 + gdz = GROUP_SIZE // bdy + threads_per_CTA = max(thread_limit, bdx * bdy) + bdz = threads_per_CTA // (bdx * bdy) + tile_size_per_bdx = TILE_SIZE_PER_BDX if GROUP_SIZE == 1 else 1 + log2e = math.log2(math.exp(1)) + check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=bdz, gdz=1) + + global_symbol = "batch_decode_paged_kv" + if sliding_window: + global_symbol += "_sliding_window" + + # pylint: disable=line-too-long,too-many-branches + # fmt: off + @T.prim_func + def batch_decode_paged_kv( + _0: T.int32, # pylint: disable=unused-argument + Q_handle: T.handle, + pages_handle: T.handle, + page_table_indptr_handle: T.handle, + page_table_values_handle: T.handle, + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + k_rope_pos_offset_handle: T.handle, + q_rope_position_handle: T.handle, + output_handle: T.handle, + lse_handle: T.handle, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol}) + B = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + + Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) + pages = T.match_buffer( + pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype + ) + page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", elem_offset=k_rope_pos_offset_elem_offset) + q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", elem_offset=q_rope_position_elem_offset) + output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype) + lse = T.match_buffer(lse_handle, (B, H_qo), "float32") # pylint: disable=unused-variable + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". + length_info = _declare_length_info(var_length_info, B, sliding_window, length_info_elem_offset) + + sm_scale = 1.0 / math.sqrt(float(D)) * log2e + + for bx in T.thread_binding(B, thread="blockIdx.x"): + for fused_by_bz in T.thread_binding(H_kv * gdz, thread="blockIdx.y"): + for ty in T.thread_binding(bdy, thread="threadIdx.y"): + for tx in T.thread_binding(bdx, thread="threadIdx.x"): + for tz in T.thread_binding(bdz, thread="threadIdx.z"): + with T.block("attn"): + Q_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + K_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") + V_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") + O_allreduce = T.alloc_buffer((bdz, bdy, D), "float32", scope="shared") + md_allreduce = T.alloc_buffer((bdz, bdy, 2), "float32", scope="shared") + S_reduce_local = T.alloc_buffer((1,), "float32", scope="local") + t0 = T.alloc_buffer((1,), "float32", scope="local") + + S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local") + QK_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + m_prev = T.alloc_buffer((1,), "float32", scope="local") + d_prev = T.alloc_buffer((1,), "float32", scope="local") + other_m = T.alloc_buffer((1,), "float32", scope="local") + other_d = T.alloc_buffer((1,), "float32", scope="local") + exp_mprev = T.alloc_buffer((1,), "float32", scope="local") + exp_otherm = T.alloc_buffer((1,), "float32", scope="local") + other_o = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + st_m = T.alloc_buffer((1,), "float32", scope="local") + st_d = T.alloc_buffer((1,), "float32", scope="local") + O_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + + by: T.int32 = fused_by_bz % H_kv + bz: T.int32 = fused_by_bz // H_kv + batch_idx: T.int32 = bx + cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] + cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, batch_idx, length_info, sliding_window), + 0 + ) + + # init states + st_m[0] = -5e4 + st_d[0] = 1.0 + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = 0.0 + + # load q + for vec in T.vectorized(VEC_SIZE): + Q_local[vec] = T.if_then_else( + rotary_mode == 1, + _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype), + Q[bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] + ) + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): + tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx # type: ignore + tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore + # load KV from global memory to shared memory + for j in T.serial(tile_size_per_bdx): + with T.block("KV_load"): + T.reads() + T.writes() + row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore + if row_g < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + for vec in T.vectorized(VEC_SIZE): + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), + pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] + ) + V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] + else: + for vec in T.vectorized(VEC_SIZE): + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 + V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 + T.tvm_storage_sync("shared") + # compute QK + m_prev[0] = st_m[0] + for j in T.serial(bdy * tile_size_per_bdx): + # compute S = Q * K * sm_scale + for vec in T.vectorized(VEC_SIZE): + QK_local[vec] = T.cast(Q_local[vec], "float32") * T.cast(K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec], "float32") * attn_score_scaling_factor * sm_scale + S_reduce_local[0] = 0 + for vec in T.unroll(VEC_SIZE): + S_reduce_local[0] += QK_local[vec] + + with T.block("block_cross_thread"): + T.reads(S_reduce_local[0]) + T.writes(t0[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], True, t0[0], tx, dtype="handle") + + S_local[j] = -5e4 + if (iterator * bdz + tz) * bdy * tile_size_per_bdx + j < kv_chunk_len[0]: + S_local[j] = t0[0] + # update st_m + st_m[0] = T.max(st_m[0], S_local[j]) + + # update st_d, st_O + o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) + st_d[0] *= o_scale + for j in T.serial(bdy * tile_size_per_bdx): + S_local[j] = T.exp2(S_local[j] - st_m[0]) + st_d[0] += S_local[j] + for j in T.vectorized(VEC_SIZE): + O_local[j] *= o_scale + + # load V from shared memory to local memory + # compute O + for j in T.serial(bdy * tile_size_per_bdx): + for vec in T.vectorized(VEC_SIZE): + V_local[vec] = V_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] + for vec in T.vectorized(VEC_SIZE): + O_local[vec] += T.cast(V_local[vec], "float32") * S_local[j] + + if bdz > 1: + # allreduce over bdz + for vec in T.vectorized(VEC_SIZE): + O_allreduce[tz, ty, tx * VEC_SIZE + vec] = O_local[vec] + md_allreduce[tz, ty, 0] = st_m[0] + md_allreduce[tz, ty, 1] = st_d[0] + T.tvm_storage_sync("shared") + + st_m[0] = -5e4 + st_d[0] = 1.0 + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = 0.0 + + for j in T.serial(bdz): + m_prev[0] = st_m[0] + d_prev[0] = st_d[0] + other_m[0] = md_allreduce[j, ty, 0] + other_d[0] = md_allreduce[j, ty, 1] + for vec in T.vectorized(VEC_SIZE): + other_o[vec] = O_allreduce[j, ty, tx * VEC_SIZE + vec] + st_m[0] = T.max(st_m[0], other_m[0]) + st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) + exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) + exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] + + # normalize O + for vec in T.vectorized(VEC_SIZE): + O_local[vec] /= st_d[0] + + # store O to global memory + for vec in T.vectorized(VEC_SIZE): + output[batch_idx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec] + + # store lse to global memory + lse[batch_idx, by * GROUP_SIZE + bz * bdy + ty] = st_m[0] + T.log2(st_d[0]) + # fmt: on + # pylint: enable=line-too-long,too-many-branches + return batch_decode_paged_kv + + +def _merge_state_inplace(num_heads, head_dim, v_dtype, target: Target): + v_dtype_bytes = 2 + VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4) + bdx = head_dim // VEC_SIZE + bdy = num_heads + max_num_threads_per_block = get_max_num_threads_per_block(target) + while bdx * bdy > max_num_threads_per_block and bdy > 1: + bdy //= 2 + gdy = num_heads // bdy + check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=1, gdz=1) + + @T.prim_func + def merge_state_inplace( + v: T.handle, + s: T.handle, + v_other: T.handle, + s_other: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1}) + N = T.int32(is_size_var=True) + H = T.int32(is_size_var=True) + D = T.int32(is_size_var=True) + + V = T.match_buffer(v, (N, H, D), v_dtype) + S = T.match_buffer(s, (N, H), "float32") + V_other = T.match_buffer(v_other, (N, H, D), v_dtype) + S_other = T.match_buffer(s_other, (N, H), "float32") + + for bx in T.thread_binding(N, thread="blockIdx.x"): + for by in T.thread_binding(gdy, thread="blockIdx.y"): + for ty in T.thread_binding(bdy, thread="threadIdx.y"): + for tx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("merge"): + s_val = _var("float32") + s_other_val = _var("float32") + s_max = _var("float32") + scale = _var("float32") + other_scale = _var("float32") + + v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") + v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") + + s_val[0] = S[bx, ty + by * bdy] + s_other_val[0] = S_other[bx, ty + by * bdy] + s_max[0] = T.max(s_val[0], s_other_val[0]) + s_val[0] = T.exp2(s_val[0] - s_max[0]) + s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) + scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) + other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) + + # load v + for vec in T.vectorized(VEC_SIZE): + v_vec[vec] = V[bx, ty + by * bdy, tx * VEC_SIZE + vec] + # load v_other + for vec in T.vectorized(VEC_SIZE): + v_other_vec[vec] = V_other[bx, ty + by * bdy, tx * VEC_SIZE + vec] + + # merge + for vec in T.serial(VEC_SIZE): + v_vec[vec] = ( + v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0] + ) + + # store v + for vec in T.vectorized(VEC_SIZE): + V[bx, ty + by * bdy, tx * VEC_SIZE + vec] = v_vec[vec] + + # store s + S[bx, ty + by * bdy] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] + + return merge_state_inplace + + +def _attention_prefill_ragged(h_kv, h_q, d, dtype, target: Target): + # pylint: disable=line-too-long + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + bdx = 32 + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + + # fmt: off + @T.prim_func + def batch_prefill_ragged_kv( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1] + var_q_rope_position: T.handle, # [total_q_len] + var_k_rope_pos_offset: T.handle, # [b] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32 + ): + batch_size = T.int32(is_size_var=True) + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + q_indptr_val: T.int32 = q_indptr[b_idx] + LH_start: T.int32 = tile_id[0] * tile_x + + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + q[cur_L, cur_H_qo, j] + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz, ly in T.grid(tile_z, tile_y): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype), + k[L_kv_base + cur_L, by, j] + ) + else: + K_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, tile_y): + with T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + V_smem[i, j] = v[L_kv_base + cur_L, by, j] + else: + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size + for j in T.serial(tile_z): + if _causal_mask(causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = (LH_start + row) // group_size + if _causal_mask(causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,too-many-branches + sch = tir.Schedule(batch_prefill_ragged_kv) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): + tx = get_max_num_threads_per_block(target) + + @T.prim_func + def copy_single_page( + var_pages: T.handle, + src_page_id: T.int64, + tgt_page_id: T.int64, + copy_length: T.int64, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) + + for b in T.thread_binding( + (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + ): + for t in T.thread_binding(tx, thread="threadIdx.x"): + with T.block("copy"): + T.where(b * tx + t < copy_length * num_heads * head_dim) + vh = T.axis.spatial( + num_heads, + T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), + ) + vp = T.axis.spatial( + copy_length, + (b * tx + t) % (copy_length * head_dim) // head_dim, + ) + vd = T.axis.spatial( + head_dim, + T.Cast( + "int32", + (b * tx + t) % head_dim, + ), + ) + pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] + pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] + + return copy_single_page + + +def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): + tx = get_max_num_threads_per_block(target) + + @T.prim_func + def compact_kv_copy( + var_pages: T.handle, + var_copy_length_indptr: T.handle, + var_copy_src_dst_pos: T.handle, + batch_size: T.int32, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + total_copy_length = T.int32() + copy_length_indptr_elem_offset = T.int32() + copy_src_dst_pos_elem_offset = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) + copy_length_indptr = T.match_buffer( + var_copy_length_indptr, + (batch_size + 1,), + "int32", + elem_offset=copy_length_indptr_elem_offset, + ) + copy_src_dst_pos = T.match_buffer( + var_copy_src_dst_pos, + (2, total_copy_length), + "int32", + elem_offset=copy_src_dst_pos_elem_offset, + ) + + with T.block("root"): + for bhd_o in T.thread_binding( + (batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + ): + for bhd_i in T.thread_binding(tx, thread="threadIdx.x"): + b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) + h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads + d: T.int32 = (bhd_o * tx + bhd_i) % head_dim + if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim: + for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]): + src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] + dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] + pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[ + src_pos // 16, 0, h, src_pos % 16, d + ] + pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[ + src_pos // 16, 1, h, src_pos % 16, d + ] + + return compact_kv_copy diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py new file mode 100644 index 000000000000..b224ce04c597 --- /dev/null +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Operators for positional embeddings, e.g. RoPE.""" + +from typing import Optional, Tuple + +from tvm import tir +from tvm.relax.frontend.nn import Tensor, op +from tvm.script import tir as T + +# pylint: disable=invalid-name + + +def rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): + """Compute the inverse frequency of RoPE and then return the cosine and sine of it. + + Parameters + ---------- + s : tir.Var + The position index. + + d : tir.Var + The dimension index. + + d_range : int + The maximum dimension index. + + theta : float + The theta value in RoPE, which controls the frequency. + + dtype : str + The data type of the output. + + Returns + ------- + cos_freq : Tensor + The cosine of the inverse frequency. + + sin_freq : Tensor + The sine of the inverse frequency. + """ + freq = s / tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) + cos_freq = tir.cos(freq).astype(dtype) + sin_freq = tir.sin(freq).astype(dtype) + return cos_freq, sin_freq + + +# mypy: disable-error-code="attr-defined" + + +def llama_rope( # pylint: disable=too-many-arguments + qkv: Tensor, + total_seq_len: tir.Var, + theta: float, + num_q_heads: int, + num_kv_heads: int, + scale: float = 1.0, + rotary_dim: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + """Llama-style RoPE. Given a fused QKV tensor, it returns three tensors, Q, K, and V, where Q + and K are rotated by RoPE while V remains unchanged. + + Parameters + ---------- + qkv : Tensor + The fused QKV tensor of shape: [batch_size, seq_len, #q_heads + #kv_heads * 2, head_dim] + + total_seq_len : tir.Var + The total sequence length after being concatenated with KVCache. It is used to compute the + offset of RoPE. + + theta : float + The theta value, or "base" in RoPE, which controls the frequency. + + scale : float + The RoPE scaling factor. + + num_q_heads : int + The number of query heads. + + num_kv_heads : int + The number of key/value heads. It differs from `num_q_heads` in group-query attention. + + rotary_dim : Optional[int] + The number of dimensions in the embedding that RoPE is applied to. By default, the + rotary_dim is the same as head_dim. + + Returns + ------- + q : Tensor + The query tensor of shape [batch_size, seq_len, #q_heads, head_dim] w/ RoPE applied + + k : Tensor + The key tensor of shape [batch_size, seq_len, #kv_heads, head_dim] w/ RoPE applied + + v : Tensor + The value tensor of shape [batch_size, seq_len, #kv_heads, head_dim] w/o RoPE applied + """ + _, _, fused_heads, head_dim = qkv.shape + assert fused_heads == num_q_heads + num_kv_heads * 2 + if rotary_dim is None: + rotary_dim = head_dim + dtype = qkv.dtype + scale = tir.const(scale, dtype) + + def _rope( # pylint: disable=too-many-arguments + x: T.Buffer, + b: tir.Var, + s: tir.Var, + h: tir.Var, + d: tir.Var, + offset: tir.Var, + ): + cos_freq, sin_freq = rope_freq((s + offset) * scale, d, rotary_dim, theta, dtype) + cos = cos_freq * x[b, s, h, d] + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -x[b, s, h, d + rotary_dim // 2], + x[b, s, h, d - rotary_dim // 2], + ) + return cos + sin + + @T.prim_func(private=True) + def fused_rope( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + total_seq_len: T.int64, + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": T.bool(True), + } + ) + batch_size = T.int64() + seq_len = T.int64() + qkv = T.match_buffer(var_qkv, (batch_size, seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (batch_size, seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (batch_size, seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (batch_size, seq_len, num_kv_heads, head_dim), dtype) + for iters in T.grid(batch_size, seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + b, s, h, d = T.axis.remap("SSSS", iters) + if h < num_q_heads: + q[b, s, h, d] = T.if_then_else( + d < rotary_dim, + _rope(qkv, b, s, h, d, total_seq_len - seq_len), + qkv[b, s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[b, s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope(qkv, b, s, h, d, total_seq_len - seq_len), + qkv[b, s, h, d], + ) + else: + v[b, s, h - (num_q_heads + num_kv_heads), d] = qkv[b, s, h, d] + + b, s, _, _ = qkv.shape + return op.tensor_ir_op( # pylint: disable=no-member + fused_rope, + "llama_rope", + args=[qkv, total_seq_len], + out=( + Tensor.placeholder((b, s, num_q_heads, head_dim), dtype), + Tensor.placeholder((b, s, num_kv_heads, head_dim), dtype), + Tensor.placeholder((b, s, num_kv_heads, head_dim), dtype), + ), + ) + + +def llama_rope_with_position_map( # pylint: disable=too-many-arguments + theta: float, + scale: float, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + dtype: str, + rotary_dim: Optional[int] = None, +): + """Return the TIR function that computes Llama-style RoPE with q position map. + + Parameters + ---------- + theta : float + The theta value, or "base" in RoPE, which controls the frequency. + + scale : float + The RoPE scaling factor. + + head_dim : int + The number of features on each head. + + num_q_heads : int + The number of query heads. + + num_kv_heads : int + The number of key/value heads. It differs from `num_q_heads` in group-query attention. + + dtype : str + The dtype of qkv data. + + rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. By default, the + rotary_dim is the same as head_dim. + """ + fused_heads = num_q_heads + num_kv_heads * 2 + if rotary_dim is None: + rotary_dim = head_dim + scale = tir.const(scale, "float32") + + def _rope( # pylint: disable=too-many-arguments + x: T.Buffer, + s: tir.Var, + h: tir.Var, + d: tir.Var, + pos: tir.Var, + ): + cos_freq, sin_freq = rope_freq(pos * scale, d, rotary_dim, theta, "float32") + cos = cos_freq * x[s, h, d].astype("float32") + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -x[s, h, d + rotary_dim // 2], + x[s, h, d - rotary_dim // 2], + ).astype("float32") + return (cos + sin).astype(dtype) + + @T.prim_func + def fused_rope( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + apply_rope: T.int32, + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": T.bool(True), + } + ) + seq_len = T.int64() + position_map_elem_offset = T.int64() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + + return fused_rope diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py new file mode 100644 index 000000000000..486491dbf2c6 --- /dev/null +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -0,0 +1,411 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name + +"""Operators for tree attention.""" + +import math +from typing import Tuple + +from tvm import tir +from tvm.runtime import DataType +from tvm.script import tir as T +from tvm.target import Target + +from .position_embedding import rope_freq + +# mypy: disable-error-code="attr-defined,valid-type,no-redef" +# pylint: disable=too-many-statements,too-many-locals,too-many-arguments + + +def _var(dtype): + return T.alloc_buffer((1,), dtype, scope="local") + + +def _rope( + buffer: T.Buffer, + offset: tir.Var, + rotary_dim: int, + theta: tir.Var, + scale: tir.Var, + indices: Tuple[tir.Var, ...], + qkv_dtype="float16", +): + d = indices[-1] + cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, qkv_dtype) + cos = cos_freq * buffer[indices] + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -buffer[indices[:-1] + (d + rotary_dim // 2,)], + buffer[indices[:-1] + (d - rotary_dim // 2,)], + ) + return cos + sin + + +def _tree_mask(row, col, mask_ptr, offset, stride, kv_len): + return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1) + + +def tree_attn(h_kv, h_q, d, dtype, target: Target): # pylint: disable=unused-argument + """Generate tree attention kernel for batched tree attention. + + Parameters + ---------- + h_kv : int + Number of heads for key and value. + h_q : int + Number of heads for query. + d : int + Hidden dimension. + dtype : str + Data type. + target : Target + The target device. + + Returns + ------- + mod : tvm.IRModule + The generated IR module. + """ + # pylint: disable=line-too-long + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + bdx = 32 + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + + # fmt: off + @T.prim_func + def batch_tree_attn( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case + var_q_rope_position: T.handle, # [total_q_len] + var_mn_indptr: T.handle, # [batch_size + 1] + var_mask: T.handle, # [mn_indptr[batch_size]] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + batch_size: T.int32, + ): + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + mn_indptr_elem_offset = T.int32(is_size_var=True) + mask_elem_offset = T.int32(is_size_var=True) + tree_size = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) + mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset) + mask = T.match_buffer(var_mask, (tree_size,), "int32", elem_offset=mask_elem_offset) + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * tile_x + q_indptr_val: T.int32 = q_indptr[b_idx] + + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + q[cur_L, cur_H_qo, j] + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz, ly in T.grid(tile_z, tile_y): + with T.block("KV_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_base + L_kv_start + i + if L_kv_start + i < kv_chunk_len[0]: + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(k, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, by, j), dtype), + k[cur_L, by, j] + ) + V_smem[i, j] = v[cur_L, by, j] + else: + K_smem[i, j] = 0.0 + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size + for j in T.serial(tile_z): + if _tree_mask( + row=row_, + col=L_kv_start + j, + mask_ptr=mask, + offset=mn_indptr[b_idx], + stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + kv_len=kv_chunk_len[0]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = (LH_start + row) // group_size + if _tree_mask( + row=row_, + col=L_kv_start + j, + mask_ptr=mask, + offset=mn_indptr[b_idx], + stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + kv_len=kv_chunk_len[0]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,too-many-branches + sch = tir.Schedule(batch_tree_attn) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("KV_load")) + + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 3c85a13e4cfc..96a2438505b2 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -16,7 +16,6 @@ # under the License. import enum import itertools -import math from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -25,12 +24,20 @@ import tvm import tvm.testing -from tvm import DataType from tvm import dlight as dl -from tvm import tir +from tvm.relax.frontend.nn.llm.kv_cache import ( + _attention_decode, + _attention_prefill, + _attention_prefill_ragged, + _compact_kv_copy, + _copy_single_page, + _kv_cache_debug_get_kv, + _kv_cache_transpose_append, + _merge_state_inplace, + llama_rope_with_position_map, + tree_attn, +) from tvm.runtime import ShapeTuple -from tvm.script import tir as T -from tvm.target import Target reserved_nseq = 32 maximum_total_seq_length = 2048 @@ -104,14 +111,14 @@ def set_global_func(head_dim, dtype): target = tvm.target.Target("cuda") builts = [] for tir_func in [ - kv_cache_transpose_append(head_dim, dtype), - copy_cache(head_dim, dtype), + _kv_cache_transpose_append(num_kv_heads, head_dim, dtype), + _kv_cache_debug_get_kv(num_layers, num_kv_heads, head_dim, dtype), _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, False, target), _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, False, target), _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, target), _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, True, target), _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype, target), - _attention_prefill_with_tree_mask(num_kv_heads, num_qo_heads, head_dim, dtype, target), + tree_attn(num_kv_heads, num_qo_heads, head_dim, dtype, target), _merge_state_inplace(num_qo_heads, head_dim, dtype, target), llama_rope_with_position_map( rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype @@ -887,1748 +894,6 @@ def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v) -def kv_cache_transpose_append(head_dim, dtype): - # undefined vars used - @T.prim_func(check_well_formed=False) - def _kv_cache_transpose_append( - var_pages: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - var_position_map: T.handle, - ): - ntoken = T.SizeVar("ntoken", "int32") - num_pages = T.int32() - position_map_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, 16, head_dim), dtype) - k_data = T.match_buffer(var_k_data, (ntoken, num_kv_heads, head_dim), dtype) - v_data = T.match_buffer(var_v_data, (ntoken, num_kv_heads, head_dim), dtype) - position_map = T.match_buffer( - var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset - ) - - for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim): - if position_map[global_pos] != T.int32(-1): - with T.block("k_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) - position: T.int64 = T.Cast("int64", position_map[vgpos]) - pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[ - vgpos, vh, vf - ] - with T.block("v_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) - position: T.int64 = T.Cast("int64", position_map[vgpos]) - pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[ - vgpos, vh, vf - ] - - return _kv_cache_transpose_append - - -def copy_cache(head_dim, dtype): - # undefined vars used - @T.prim_func(check_well_formed=False) - def _copy_cache( - var_pages: T.handle, - var_position_map: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - layer_id: T.int64, - ): - num_kv_heads = T.int64() - seqlen = T.SizeVar("seqlen", "int64") - page_size = T.int64() - num_pages = T.int64() - position_map_elem_offset = T.int64() - pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), dtype) - position_map = T.match_buffer( - var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset - ) - k_data = T.match_buffer(var_k_data, (num_layers, seqlen, num_kv_heads, head_dim), dtype) - v_data = T.match_buffer(var_v_data, (num_layers, seqlen, num_kv_heads, head_dim), dtype) - - for p, h, d in T.grid(seqlen, num_kv_heads, head_dim): - with T.block("copy0"): - vp, vh, vd = T.axis.remap("SSS", [p, h, d]) - T.reads( - position_map[vp], - pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd], - ) - T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) - position: T.int64 = T.Cast("int64", position_map[vp]) - k_data[layer_id, vp, vh, vd] = pages[ - T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd - ] - v_data[layer_id, vp, vh, vd] = pages[ - T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd - ] - - return _copy_cache - - -def llama_rope_with_position_map( # pylint: disable=too-many-arguments - theta: float, - scale: float, - head_dim: int, - num_q_heads: int, - num_kv_heads: int, - dtype: float = "float16", - rotary_dim: int = None, -): - fused_heads = num_q_heads + num_kv_heads * 2 - if rotary_dim is None: - rotary_dim = head_dim - scale = tir.const(scale, dtype) - - def _rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): - freq = s / tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) - cos_freq = tir.cos(freq).astype(dtype) - sin_freq = tir.sin(freq).astype(dtype) - return cos_freq, sin_freq - - def _rope( # pylint: disable=too-many-arguments - x: T.Buffer, - s: tir.Var, - h: tir.Var, - d: tir.Var, - pos: tir.Var, - ): - cos_freq, sin_freq = _rope_freq(pos * scale, d, rotary_dim, theta, dtype) - cos = cos_freq * x[s, h, d] - sin = sin_freq * tir.if_then_else( - d < rotary_dim // 2, - -x[s, h, d + rotary_dim // 2], - x[s, h, d - rotary_dim // 2], - ) - return cos + sin - - # undefined vars used - @T.prim_func(private=True, check_well_formed=False) - def fused_rope( # pylint: disable=too-many-locals - var_qkv: T.handle, - var_position_map: T.handle, - var_q: T.handle, - var_k: T.handle, - var_v: T.handle, - apply_rope: T.int32, - ): - T.func_attr( - { - "op_pattern": 8, # 2 means injective, 8 means opaque - "tir.noalias": T.bool(True), - } - ) - seq_len = T.int64() - position_map_elem_offset = T.int64() - qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) - q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) - k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) - v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) - position_map = T.match_buffer( - var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset - ) - for iters in T.grid(seq_len, fused_heads, head_dim): - with T.block("llama_fused_rope"): - s, h, d = T.axis.remap("SSS", iters) - if h < num_q_heads: - q[s, h, d] = T.if_then_else( - apply_rope > 0 and d < rotary_dim, - _rope(qkv, s, h, d, position_map[s]), - qkv[s, h, d], - ) - elif h < num_q_heads + num_kv_heads: - k[s, h - num_q_heads, d] = T.if_then_else( - apply_rope > 0 and d < rotary_dim, - _rope(qkv, s, h, d, position_map[s]), - qkv[s, h, d], - ) - else: - v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] - - return fused_rope - - -def rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): - """Compute the inverse frequency of RoPE and then return the cosine and sine of it. - - Parameters - ---------- - s : tir.Var - The position index. - - d : tir.Var - The dimension index. - - d_range : int - The maximum dimension index. - - theta : float - The theta value in RoPE, which controls the frequency. - - dtype : str - The data type of the output. - - Returns - ------- - cos_freq : Tensor - The cosine of the inverse frequency. - - sin_freq : Tensor - The sine of the inverse frequency. - """ - freq = s / tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) - cos_freq = tir.cos(freq).astype(dtype) - sin_freq = tir.sin(freq).astype(dtype) - return cos_freq, sin_freq - - -def _rope( # pylint: disable=too-many-arguments - buffer: T.Buffer, - offset: tir.Var, - rotary_dim: int, - theta: tir.Var, - scale: tir.Var, - indices: Tuple[tir.Var, ...], - qkv_dtype="float16", -): - d = indices[-1] - cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, qkv_dtype) - cos = cos_freq * buffer[indices] - sin = sin_freq * tir.if_then_else( - d < rotary_dim // 2, - -buffer[indices[:-1] + (d + rotary_dim // 2,)], - buffer[indices[:-1] + (d - rotary_dim // 2,)], - ) - return cos + sin - - -def _var(dtype): - return T.alloc_buffer((1,), dtype, scope="local") - - -def _causal_mask(causal, row, col, kv_len, qo_len): - return T.if_then_else( - causal > 0, - col < kv_len - qo_len + row + 1, - col < kv_len, - ) - - -def _declare_length_info(var_length_info, batch_size, sliding_window, elem_offset): - return ( - T.match_buffer(var_length_info, (3, batch_size), "int32", elem_offset=elem_offset) - if sliding_window - else T.match_buffer(var_length_info, (batch_size,), "int32", elem_offset=elem_offset) - ) - - -def _get_kv_chunk_len(num_pages, page_size, seq_id, length_info, sliding_window): - if not sliding_window: - return (num_pages - 1) * page_size + length_info[seq_id] - else: - # ((num_pages - 1) * page_size + last_page_len) - sliding_window_offset + sink_size - return ( - (num_pages - 1) * page_size - + length_info[0, seq_id] - - length_info[1, seq_id] - + length_info[2, seq_id] - ) - - -def _get_seq_offset(pos, seq_id, length_info, sliding_window): - if not sliding_window: - return pos - else: - # pos if pos < sink_size else pos - sink_size + sliding_window_offset - return T.if_then_else( - pos < length_info[2, seq_id], - pos, - pos - length_info[2, seq_id] + length_info[1, seq_id], - ) - - -def get_max_num_threads_per_block(target: Target): - """ - max(max_num_threads, max_threads_per_block); if latter does not exist, return max_num_threads. - We add this method since some targets have both fields and `max_threads_per_block` is larger. - """ - max_num_threads = target.max_num_threads - max_threads_per_block = target.attrs.get("max_threads_per_block", None) - if max_threads_per_block is None: - return max_num_threads - return max(max_num_threads, max_threads_per_block) - - -def _attention_prefill( - h_kv, h_q, d, dtype, sliding_window: bool, target: Target -): # pylint: disable=unused-argument - # pylint: disable=invalid-name - NUM_BLKS = 16 - LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes - group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) - - bdx = 32 - num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 - L_per_cta = tile_x // group_size - - # Otherwise we would exceed maxComputeWorkgroupStorageSize - if ( - str(target.kind) == "webgpu" - and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 - ): - tile_z = 8 - num_warps = 2 - - # undefined vars used - # pylint: disable=line-too-long,too-many-arguments,too-many-branches - # fmt: off - @T.prim_func(check_well_formed=False) - def batch_prefill_paged_kv( - _0: T.int32, # pylint: disable=unused-argument - var_q: T.handle, # [total_len, h_q, d] - var_q_indptr: T.handle, # [batch_size + 1] - var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] - var_page_indptr: T.handle, # [batch_size + 1] - var_page_values: T.handle, # [nnz_pages] - var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] - var_k_rope_pos_offset: T.handle, # [b] - var_q_rope_position: T.handle, # [total_len] - var_output: T.handle, # [total_len, h_q, d] - var_lse: T.handle, # [total_len, h_q] - causal: T.int32, - rotary_mode: T.int32, - rope_scale: T.float32, - rope_theta: T.float32, - attn_score_scaling_factor: T.float32, - ): - batch_size = T.int32(is_size_var=True) - total_len = T.int32(is_size_var=True) - nnz_pages = T.int32(is_size_var=True) - max_num_pages = T.int32(is_size_var=True) - q_indptr_elem_offset = T.int32(is_size_var=True) - page_indptr_elem_offset = T.int32(is_size_var=True) - page_values_elem_offset = T.int32(is_size_var=True) - k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) - q_rope_position_elem_offset = T.int32(is_size_var=True) - length_info_elem_offset = T.int32(is_size_var=True) - - q = T.match_buffer(var_q, (total_len, h_q, d), dtype) - q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) - pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) - page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) - page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) - k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) - q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset) - output = T.match_buffer(var_output, (total_len, h_q, d), dtype) - lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable - # The length information of the sequences. - # - It is in shape `(3, batch_size)` when sliding window is enabled. - # For a sequence "i", location - # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), - # - "(1, i)" is the starting offset of the sliding window in the seq, - # - "(2, i)" is the attn sink length of the sequence. - # - It is in shape `(batch_size,)` when sliding window is disabled, - # denoting the "last_page_len". - length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset) - - # kernel code - for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): - for lby in T.thread_binding(h_kv, thread="blockIdx.y"): - for lty in T.thread_binding(num_warps, thread="threadIdx.y"): - for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): - bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) - T.reads() - T.writes() - tile_id = _var("int32") - batch_idx = _var("int32") - batch_tiles = _var("int32") - batch_rows = _var("int32") - iterator = _var("int32") - kv_chunk_len = _var("int32") - - Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") - K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") - V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") - S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") - - S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") - O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") - - m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - - m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - - ## get tile_no, batch_idx, batch_tiles, batch_rows - tile_id[0] = bx - batch_idx[0] = 0 - batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size - batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - while T.tvm_thread_invariant(batch_idx[0] < batch_size): - # advance to next tile - while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: - tile_id[0] -= batch_tiles[0] - batch_idx[0] += 1 - if batch_idx[0] < batch_size: - b_idx: T.int32 = batch_idx[0] - batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size - batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - - if T.tvm_thread_invariant(batch_idx[0] < batch_size): - b_idx: T.int32 = batch_idx[0] - LH_start: T.int32 = tile_id[0] * tile_x - q_indptr_val: T.int32 = q_indptr[b_idx] - - cur_page_indptr_begin: T.int32 = page_indptr[b_idx] - cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] - kv_chunk_len[0] = T.if_then_else( - cur_page_indptr_begin != cur_page_indptr_end, - _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), - 0 - ) - T.tvm_storage_sync("shared") - - # init states - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - m_smem[row] = -5e4 - d_smem[row] = 1.0 - - for li, lj in T.grid(tile_x, tile_y): - with T.block("O_init"): - i, j = T.axis.remap("SS", [li, lj]) - O_local[i, j] = 0.0 - T.tvm_storage_sync("shared") - - # Load Q from gmem to smem - for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): - i, j = T.axis.remap("SS", [li, lj]) - T.reads() - T.writes() - cur_L = q_indptr_val + (LH_start + i) // group_size - cur_H_qo = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - Q_smem[i, j] = T.if_then_else( - rotary_mode == 1, - _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), - q[cur_L, cur_H_qo, j] - ) - else: - Q_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - - for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): - L_kv_start: T.int32 = iterator * tile_z - for lz, ly in T.grid(tile_z, tile_y): - with T.block("K_load"): - i, j = T.axis.remap("SS", [lz, ly]) - T.reads() - T.writes() - cur_L = L_kv_start + i - if cur_L < kv_chunk_len[0]: - seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore - K_smem[i, j] = T.if_then_else( - rotary_mode == 1, - _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype), - pages[page_no, 0, by, page_offset, j] - ) - else: - K_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - for lz, ly in T.grid(tile_z, tile_y): - with T.block("V_load"): - i, j = T.axis.remap("SS", [lz, ly]) - T.reads() - T.writes() - cur_L = L_kv_start + i - if cur_L < kv_chunk_len[0]: - seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore - V_smem[i, j] = pages[page_no, 1, by, page_offset, j] - else: - V_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - - # Compute S - with T.block(): - for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): - i, j, k = T.axis.remap("SSR", [li, lj, lk]) - with T.init(): - S_local[i, j] = 0.0 - S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale - T.tvm_storage_sync("shared") - for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): - i, j = T.axis.remap("SS", [li, lj]) - S_smem[i, j] = S_local[i, j] - T.tvm_storage_sync("shared") - - # Update S, m, d - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - with T.block("update1"): - m_prev[i] = m_smem[row] - m_new[i] = m_smem[row] - # mask out of kv_chunk_len S - row_: T.int32 = (LH_start + row) // group_size - for j in T.serial(tile_z): - if _causal_mask(causal, - row=row_, - col=L_kv_start + j, - kv_len=kv_chunk_len[0], - qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): - m_new[i] = T.max(m_new[i], S_smem[row, j]) - d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) - - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): - for j in T.serial(tile_z): - # this is to avoid sync inside condition branch - if row < tile_x: - row_: T.int32 = (LH_start + row) // group_size - if _causal_mask(causal, - row=row_, - col=L_kv_start + j, - kv_len=kv_chunk_len[0], - qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): - S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) - else: - S_smem[row, j] = T.exp2(-5e4 - m_new[i]) - - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - with T.block("update"): - for j in T.serial(tile_z): - d_new[i] += S_smem[row, j] - m_smem[row] = m_new[i] - d_smem[row] = d_new[i] - m_prev_smem[row] = m_prev[i] - T.tvm_storage_sync("shared") - - # Update O - with T.block(): - for li, lj, lk in T.grid(tile_x, tile_y, tile_z): - with T.block("O_gemm"): - i, j, k = T.axis.remap("SSR", [li, lj, lk]) - with T.init(): - O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) - O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") - - # Store O from smem to gmem - for li, lj in T.grid(tile_x, tile_y): - with T.block("O_store"): - i, j = T.axis.remap("SS", [li, lj]) - cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size - cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] - - # Store LSE to gmem - for li in T.grid(tile_x): - with T.block("lse_store"): - i = T.axis.remap("S", [li]) - cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size - cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) - - # move to next tile - tile_id[0] += NUM_BLKS - # fmt: on - # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches - sch = tir.Schedule(batch_prefill_paged_kv) - - def get_tile_size(x, y, t): - cnt = (x * y) // t - assert (x * y) % t == 0 - tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: - tile_y += 1 - assert tile_y <= cnt - tile_x = cnt // tile_y - return tile_x, tile_y - - def apply_to_qkv_load(sch: tir.Schedule, block): - loop_x, loop_y = sch.get_loops(block)[-2:] - loop = sch.fuse(loop_x, loop_y) - _, ty, tx, vec = sch.split( - loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True - ) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - - def apply_to_so_ewise(sch: tir.Schedule, block, tile): - loop_x, loop_y = sch.get_loops(block)[-2:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False - ): - loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - ko, ki = sch.split(loop_z, factors=[None, r_len]) - if k_major: - sch.reorder(ko, xi, yi, ki) - else: - sch.reorder(ko, ki, xi, yi) - sch.decompose_reduction(block, ty) - - def apply_to_md(sch, block): - loop = sch.get_loops(block)[-1] - _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) - tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) - apply_to_qkv_load(sch, sch.get_block("K_load")) - apply_to_qkv_load(sch, sch.get_block("V_load")) - apply_to_md(sch, sch.get_block("lse_store")) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -def _attention_decode( - num_kv_heads, - num_qo_heads, - head_dim, - qkv_dtype, - sliding_window: bool, - target: Target, # pylint: disable=unused-argument -): - # pylint: disable=invalid-name - qkv_dtype_bytes = 2 - H_qo = num_qo_heads - H_kv = num_kv_heads - D = head_dim - - THREAD_LIMIT = 512 - TILE_SIZE_PER_BDX = 2 - if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 64 - TILE_SIZE_PER_BDX = 1 - max_num_threads_per_block = get_max_num_threads_per_block(target) - thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) - - GROUP_SIZE = H_qo // H_kv - VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4) - bdx = D // VEC_SIZE - bdy = GROUP_SIZE - while bdx * bdy > thread_limit and bdy > 1: - bdy //= 2 - gdz = GROUP_SIZE // bdy - threads_per_CTA = max(thread_limit, bdx * bdy) - bdz = threads_per_CTA // (bdx * bdy) - tile_size_per_bdx = TILE_SIZE_PER_BDX if GROUP_SIZE == 1 else 1 - log2e = math.log2(math.exp(1)) - - # undefined vars used - # pylint: disable=line-too-long,too-many-arguments,too-many-branches - # fmt: off - @T.prim_func(check_well_formed=False) - def batch_decode_paged_kv( - _0: T.int32, # pylint: disable=unused-argument - Q_handle: T.handle, - pages_handle: T.handle, - page_table_indptr_handle: T.handle, - page_table_values_handle: T.handle, - var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] - k_rope_pos_offset_handle: T.handle, - q_rope_position_handle: T.handle, - output_handle: T.handle, - lse_handle: T.handle, - rotary_mode: T.int32, - rope_scale: T.float32, - rope_theta: T.float32, - attn_score_scaling_factor: T.float32, - ): - T.func_attr({"tir.is_scheduled": 1}) - B = T.int32(is_size_var=True) - nnz_pages = T.int32(is_size_var=True) - max_num_pages = T.int32(is_size_var=True) - page_indptr_elem_offset = T.int32(is_size_var=True) - page_values_elem_offset = T.int32(is_size_var=True) - k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) - q_rope_position_elem_offset = T.int32(is_size_var=True) - length_info_elem_offset = T.int32(is_size_var=True) - - Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) - pages = T.match_buffer( - pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype - ) - page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset) - page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) - k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", elem_offset=k_rope_pos_offset_elem_offset) - q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", elem_offset=q_rope_position_elem_offset) - output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype) - lse = T.match_buffer(lse_handle, (B, H_qo), "float32") # pylint: disable=unused-variable - # The length information of the sequences. - # - It is in shape `(3, batch_size)` when sliding window is enabled. - # For a sequence "i", location - # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), - # - "(1, i)" is the starting offset of the sliding window in the seq, - # - "(2, i)" is the attn sink length of the sequence. - # - It is in shape `(batch_size,)` when sliding window is disabled, - # denoting the "last_page_len". - length_info = _declare_length_info(var_length_info, B, sliding_window, length_info_elem_offset) - - sm_scale = 1.0 / math.sqrt(float(D)) * log2e - - for bx in T.thread_binding(B, thread="blockIdx.x"): - for fused_by_bz in T.thread_binding(H_kv * gdz, thread="blockIdx.y"): - for ty in T.thread_binding(bdy, thread="threadIdx.y"): - for tx in T.thread_binding(bdx, thread="threadIdx.x"): - for tz in T.thread_binding(bdz, thread="threadIdx.z"): - with T.block("attn"): - Q_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") - kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") - K_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") - V_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") - O_allreduce = T.alloc_buffer((bdz, bdy, D), "float32", scope="shared") - md_allreduce = T.alloc_buffer((bdz, bdy, 2), "float32", scope="shared") - S_reduce_local = T.alloc_buffer((1,), "float32", scope="local") - t0 = T.alloc_buffer((1,), "float32", scope="local") - - S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local") - K_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") - V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") - m_prev = T.alloc_buffer((1,), "float32", scope="local") - d_prev = T.alloc_buffer((1,), "float32", scope="local") - other_m = T.alloc_buffer((1,), "float32", scope="local") - other_d = T.alloc_buffer((1,), "float32", scope="local") - other_o = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") - st_m = T.alloc_buffer((1,), "float32", scope="local") - st_d = T.alloc_buffer((1,), "float32", scope="local") - O_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") - - by: T.int32 = fused_by_bz % H_kv - bz: T.int32 = fused_by_bz // H_kv - batch_idx: T.int32 = bx - cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] - cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] - kv_chunk_len[0] = T.if_then_else( - cur_page_indptr_begin != cur_page_indptr_end, - _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, batch_idx, length_info, sliding_window), - 0 - ) - - # init states - st_m[0] = -5e4 - st_d[0] = 1.0 - for vec in T.vectorized(VEC_SIZE): - O_local[vec] = 0.0 - - # load q - for vec in T.vectorized(VEC_SIZE): - Q_local[vec] = T.if_then_else( - rotary_mode == 1, - _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype), - Q[bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] - ) - - for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): - tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx # type: ignore - tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore - # load K from global memory to shared memory - for j in T.serial(tile_size_per_bdx): - with T.block("K_load"): - T.reads() - T.writes() - row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore - if row_g < kv_chunk_len[0]: - seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore - for vec in T.vectorized(VEC_SIZE): - K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( - rotary_mode == 1, - _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), - pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] - ) - else: - for vec in T.vectorized(VEC_SIZE): - K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 - T.tvm_storage_sync("shared") - # load V from global memory to shared memory - for j in T.serial(tile_size_per_bdx): - with T.block("V_load"): - T.reads() - T.writes() - row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore - if row_g < kv_chunk_len[0]: - seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore - for vec in T.vectorized(VEC_SIZE): - V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] - else: - for vec in T.vectorized(VEC_SIZE): - V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 - T.tvm_storage_sync("shared") - # compute QK - m_prev[0] = st_m[0] - for j in T.serial(bdy * tile_size_per_bdx): - # load K from shared memory to local memory - for vec in T.vectorized(VEC_SIZE): - K_local[vec] = K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] - # compute S = Q * K * sm_scale - S_reduce_local[0] = 0 - for vec in T.serial(VEC_SIZE): - S_reduce_local[0] += T.cast(Q_local[vec], "float32") * T.cast(K_local[vec], "float32") * attn_score_scaling_factor * sm_scale - - with T.block("block_cross_thread"): - T.reads(S_reduce_local[0]) - T.writes(t0[0]) - T.attr( - T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), - "reduce_scope", - T.reinterpret("handle", T.uint64(0)), - ) - T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], True, t0[0], tx, dtype="handle") - - S_local[j] = -5e4 - if (iterator * bdz + tz) * bdy * tile_size_per_bdx + j < kv_chunk_len[0]: - S_local[j] = t0[0] - # update st_m - st_m[0] = T.max(st_m[0], S_local[j]) - - # update st_d, st_O - o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) - st_d[0] *= o_scale - for j in T.serial(bdy * tile_size_per_bdx): - S_local[j] = T.exp2(S_local[j] - st_m[0]) - st_d[0] += S_local[j] - for j in T.vectorized(VEC_SIZE): - O_local[j] *= o_scale - - # load V from shared memory to local memory - # compute O - for j in T.serial(bdy * tile_size_per_bdx): - for vec in T.vectorized(VEC_SIZE): - V_local[vec] = V_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] - for vec in T.vectorized(VEC_SIZE): - O_local[vec] += T.cast(V_local[vec], "float32") * S_local[j] - - if bdz > 1: - # allreduce over bdz - for vec in T.vectorized(VEC_SIZE): - O_allreduce[tz, ty, tx * VEC_SIZE + vec] = O_local[vec] - md_allreduce[tz, ty, 0] = st_m[0] - md_allreduce[tz, ty, 1] = st_d[0] - T.tvm_storage_sync("shared") - - st_m[0] = -5e4 - st_d[0] = 1.0 - for vec in T.vectorized(VEC_SIZE): - O_local[vec] = 0.0 - - for j in T.serial(bdz): - m_prev[0] = st_m[0] - d_prev[0] = st_d[0] - other_m[0] = md_allreduce[j, ty, 0] - other_d[0] = md_allreduce[j, ty, 1] - for vec in T.vectorized(VEC_SIZE): - other_o[vec] = O_allreduce[j, ty, tx * VEC_SIZE + vec] - st_m[0] = T.max(st_m[0], other_m[0]) - st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) - for vec in T.serial(VEC_SIZE): - O_local[vec] = O_local[vec] * T.exp2(m_prev[0] - st_m[0]) + other_o[vec] * T.exp2(other_m[0] - st_m[0]) - - # normalize O - for vec in T.serial(VEC_SIZE): - O_local[vec] /= st_d[0] - - # store O to global memory - for vec in T.vectorized(VEC_SIZE): - output[batch_idx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec] - - # store lse to global memory - lse[batch_idx, by * GROUP_SIZE + bz * bdy + ty] = st_m[0] + T.log2(st_d[0]) - # fmt: on - # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches - return batch_decode_paged_kv - - -def _attention_prefill_ragged( - h_kv, h_q, d, dtype, target: Target -): # pylint: disable=unused-argument - # pylint: disable=invalid-name,line-too-long - NUM_BLKS = 16 - LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes - group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) - - bdx = 32 - num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 - - # Otherwise we would exceed maxComputeWorkgroupStorageSize - if ( - str(target.kind) == "webgpu" - and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 - ): - tile_z = 8 - num_warps = 2 - - # undefined vars used - # fmt: off - @T.prim_func(check_well_formed=False) - def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches - var_q: T.handle, # [total_len, h_q, d] - var_q_indptr: T.handle, # [batch_size + 1] - var_k: T.handle, # [total_len, h_kv, d] - var_v: T.handle, # [total_len, h_kv, d] - var_kv_indptr: T.handle, # [batch_size + 1] - var_q_rope_position: T.handle, # [total_q_len] - var_k_rope_pos_offset: T.handle, # [b] - var_output: T.handle, # [total_len, h_q, d] - var_lse: T.handle, # [total_len, h_q] - causal: T.int32, - rotary_mode: T.int32, - rope_scale: T.float32, - rope_theta: T.float32, - attn_score_scaling_factor: T.float32 - ): - batch_size = T.int32(is_size_var=True) - qo_len = T.int32(is_size_var=True) - kv_len = T.int32(is_size_var=True) - q_indptr_elem_offset = T.int32(is_size_var=True) - kv_indptr_elem_offset = T.int32(is_size_var=True) - q_rope_position_elem_offset = T.int32(is_size_var=True) - k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) - - q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) - q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) - k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) - v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) - kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) - q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) - k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) - output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) - lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable - - # kernel code - for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): - for lby in T.thread_binding(h_kv, thread="blockIdx.y"): - for lty in T.thread_binding(num_warps, thread="threadIdx.y"): - for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): - bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) - T.reads() - T.writes() - tile_id = _var("int32") - batch_idx = _var("int32") - batch_tiles = _var("int32") - batch_rows = _var("int32") - iterator = _var("int32") - kv_chunk_len = _var("int32") - - Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") - K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") - V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") - S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") - - S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") - O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") - - m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - - m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - - ## get tile_no, batch_idx, batch_tiles, batch_rows - tile_id[0] = bx - batch_idx[0] = 0 - batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size - batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - while T.tvm_thread_invariant(batch_idx[0] < batch_size): - # advance to next tile - while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: - tile_id[0] -= batch_tiles[0] - batch_idx[0] += 1 - if batch_idx[0] < batch_size: - b_idx: T.int32 = batch_idx[0] - batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size - batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - - if T.tvm_thread_invariant(batch_idx[0] < batch_size): - b_idx: T.int32 = batch_idx[0] - LH_start: T.int32 = tile_id[0] * tile_x - q_indptr_val: T.int32 = q_indptr[b_idx] - - kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] - T.tvm_storage_sync("shared") - - # init states - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - m_smem[row] = -5e4 - d_smem[row] = 1.0 - - for li, lj in T.grid(tile_x, tile_y): - with T.block("O_init"): - i, j = T.axis.remap("SS", [li, lj]) - O_local[i, j] = 0.0 - T.tvm_storage_sync("shared") - - # Load Q from gmem to smem - for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): - i, j = T.axis.remap("SS", [li, lj]) - T.reads() - T.writes() - cur_L = q_indptr_val + (LH_start + i) // group_size - cur_H_qo = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - Q_smem[i, j] = T.if_then_else( - rotary_mode == 1, - _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), - q[cur_L, cur_H_qo, j] - ) - else: - Q_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - - for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): - L_kv_start: T.int32 = iterator * tile_z - L_kv_base: T.int32 = kv_indptr[b_idx] - for lz, ly in T.grid(tile_z, tile_y): - with T.block("K_load"): - i, j = T.axis.remap("SS", [lz, ly]) - T.reads() - T.writes() - cur_L = L_kv_start + i - if cur_L < kv_chunk_len[0]: - K_smem[i, j] = T.if_then_else( - rotary_mode == 1, - _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype), - k[L_kv_base + cur_L, by, j] - ) - else: - K_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - for lz, ly in T.grid(tile_z, tile_y): - with T.block("V_load"): - i, j = T.axis.remap("SS", [lz, ly]) - T.reads() - T.writes() - cur_L = L_kv_start + i - if cur_L < kv_chunk_len[0]: - V_smem[i, j] = v[L_kv_base + cur_L, by, j] - else: - V_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - - # Compute S - with T.block(): - for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): - i, j, k = T.axis.remap("SSR", [li, lj, lk]) - with T.init(): - S_local[i, j] = 0.0 - S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale - T.tvm_storage_sync("shared") - for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): - i, j = T.axis.remap("SS", [li, lj]) - S_smem[i, j] = S_local[i, j] - T.tvm_storage_sync("shared") - - # Update S, m, d - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - with T.block("update1"): - m_prev[i] = m_smem[row] - m_new[i] = m_smem[row] - # mask out of kv_chunk_len S - row_: T.int32 = (LH_start + row) // group_size - for j in T.serial(tile_z): - if _causal_mask(causal, - row=row_, - col=L_kv_start + j, - kv_len=kv_chunk_len[0], - qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): - m_new[i] = T.max(m_new[i], S_smem[row, j]) - d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) - - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): - for j in T.serial(tile_z): - # this is to avoid sync inside condition branch - if row < tile_x: - row_: T.int32 = (LH_start + row) // group_size - if _causal_mask(causal, - row=row_, - col=L_kv_start + j, - kv_len=kv_chunk_len[0], - qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): - S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) - else: - S_smem[row, j] = T.exp2(-5e4 - m_new[i]) - - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - with T.block("update"): - for j in T.serial(tile_z): - d_new[i] += S_smem[row, j] - m_smem[row] = m_new[i] - d_smem[row] = d_new[i] - m_prev_smem[row] = m_prev[i] - T.tvm_storage_sync("shared") - - # Update O - with T.block(): - for li, lj, lk in T.grid(tile_x, tile_y, tile_z): - with T.block("O_gemm"): - i, j, k = T.axis.remap("SSR", [li, lj, lk]) - with T.init(): - O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) - O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") - - # Store O from smem to gmem - for li, lj in T.grid(tile_x, tile_y): - with T.block("O_store"): - i, j = T.axis.remap("SS", [li, lj]) - cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size - cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] - - # Store LSE to gmem - for li in T.grid(tile_x): - with T.block("lse_store"): - i = T.axis.remap("S", [li]) - cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size - cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) - - # move to next tile - tile_id[0] += NUM_BLKS - # fmt: on - # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches - sch = tir.Schedule(batch_prefill_ragged_kv) - - def get_tile_size(x, y, t): - cnt = (x * y) // t - assert (x * y) % t == 0 - tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: - tile_y += 1 - assert tile_y <= cnt - tile_x = cnt // tile_y - return tile_x, tile_y - - def apply_to_qkv_load(sch: tir.Schedule, block): - loop_x, loop_y = sch.get_loops(block)[-2:] - loop = sch.fuse(loop_x, loop_y) - _, ty, tx, vec = sch.split( - loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True - ) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - - def apply_to_so_ewise(sch: tir.Schedule, block, tile): - loop_x, loop_y = sch.get_loops(block)[-2:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False - ): - loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - ko, ki = sch.split(loop_z, factors=[None, r_len]) - if k_major: - sch.reorder(ko, xi, yi, ki) - else: - sch.reorder(ko, ki, xi, yi) - sch.decompose_reduction(block, ty) - - def apply_to_md(sch, block): - loop = sch.get_loops(block)[-1] - _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) - tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) - apply_to_qkv_load(sch, sch.get_block("K_load")) - apply_to_qkv_load(sch, sch.get_block("V_load")) - - apply_to_md(sch, sch.get_block("lse_store")) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -def _tree_mask(row, col, mask_ptr, offset, stride, kv_len): - return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1) - - -def _attention_prefill_with_tree_mask( - h_kv, h_q, d, dtype, target: Target -): # pylint: disable=unused-argument - # pylint: disable=invalid-name,line-too-long - NUM_BLKS = 16 - LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes - group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) - - bdx = 32 - num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 - L_per_cta = tile_x // group_size - - # Otherwise we would exceed maxComputeWorkgroupStorageSize - if ( - str(target.kind) == "webgpu" - and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 - ): - tile_z = 8 - num_warps = 2 - - # fmt: off - @T.prim_func - def batch_tree_attn( # pylint: disable=too-many-branches - var_q: T.handle, # [total_len, h_q, d] - var_q_indptr: T.handle, # [batch_size + 1] - var_k: T.handle, # [total_len, h_kv, d] - var_v: T.handle, # [total_len, h_kv, d] - var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case - var_q_rope_position: T.handle, # [total_q_len] - var_mn_indptr: T.handle, # [batch_size + 1] - var_mask: T.handle, # [mn_indptr[batch_size]] - var_output: T.handle, # [total_len, h_q, d] - var_lse: T.handle, # [total_len, h_q] - rotary_mode: T.int32, - rope_scale: T.float32, - rope_theta: T.float32, - attn_score_scaling_factor: T.float32, - batch_size: T.int32, - ): - qo_len = T.int32(is_size_var=True) - kv_len = T.int32(is_size_var=True) - q_indptr_elem_offset = T.int32(is_size_var=True) - kv_indptr_elem_offset = T.int32(is_size_var=True) - q_rope_position_elem_offset = T.int32(is_size_var=True) - mn_indptr_elem_offset = T.int32(is_size_var=True) - mask_elem_offset = T.int32(is_size_var=True) - tree_size = T.int32(is_size_var=True) - - q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) - q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) - k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) - v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) - kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) - q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) - mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset) - mask = T.match_buffer(var_mask, (tree_size,), "int32", elem_offset=mask_elem_offset) - output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) - lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable - - # kernel code - for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): - for lby in T.thread_binding(h_kv, thread="blockIdx.y"): - for lty in T.thread_binding(num_warps, thread="threadIdx.y"): - for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): - bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) - T.reads() - T.writes() - tile_id = _var("int32") - batch_idx = _var("int32") - batch_tiles = _var("int32") - batch_rows = _var("int32") - iterator = _var("int32") - kv_chunk_len = _var("int32") - - Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") - K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") - V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") - S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") - - S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") - O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") - - m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - - m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - - ## get tile_no, batch_idx, batch_tiles, batch_rows - tile_id[0] = bx - batch_idx[0] = 0 - batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size - batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - while T.tvm_thread_invariant(batch_idx[0] < batch_size): - # advance to next tile - while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: - tile_id[0] -= batch_tiles[0] - batch_idx[0] += 1 - if batch_idx[0] < batch_size: - b_idx: T.int32 = batch_idx[0] - batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size - batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - - if T.tvm_thread_invariant(batch_idx[0] < batch_size): - b_idx: T.int32 = batch_idx[0] - LH_start: T.int32 = tile_id[0] * tile_x - q_indptr_val: T.int32 = q_indptr[b_idx] - - kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] - T.tvm_storage_sync("shared") - - # init states - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - m_smem[row] = -5e4 - d_smem[row] = 1.0 - - for li, lj in T.grid(tile_x, tile_y): - with T.block("O_init"): - i, j = T.axis.remap("SS", [li, lj]) - O_local[i, j] = 0.0 - T.tvm_storage_sync("shared") - - # Load Q from gmem to smem - for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): - i, j = T.axis.remap("SS", [li, lj]) - T.reads() - T.writes() - cur_L = q_indptr_val + (LH_start + i) // group_size - cur_H_qo = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - Q_smem[i, j] = T.if_then_else( - rotary_mode == 1, - _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), - q[cur_L, cur_H_qo, j] - ) - else: - Q_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - - for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): - L_kv_start: T.int32 = iterator * tile_z - L_kv_base: T.int32 = kv_indptr[b_idx] - for lz, ly in T.grid(tile_z, tile_y): - with T.block("KV_load"): - i, j = T.axis.remap("SS", [lz, ly]) - T.reads() - T.writes() - cur_L = L_kv_base + L_kv_start + i - if L_kv_start + i < kv_chunk_len[0]: - K_smem[i, j] = T.if_then_else( - rotary_mode == 1, - _rope(k, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, by, j), dtype), - k[cur_L, by, j] - ) - V_smem[i, j] = v[cur_L, by, j] - else: - K_smem[i, j] = 0.0 - V_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - - # Compute S - with T.block(): - for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): - i, j, k = T.axis.remap("SSR", [li, lj, lk]) - with T.init(): - S_local[i, j] = 0.0 - S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale - T.tvm_storage_sync("shared") - for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): - i, j = T.axis.remap("SS", [li, lj]) - S_smem[i, j] = S_local[i, j] - T.tvm_storage_sync("shared") - - # Update S, m, d - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - with T.block("update1"): - m_prev[i] = m_smem[row] - m_new[i] = m_smem[row] - # mask out of kv_chunk_len S - row_: T.int32 = (LH_start + row) // group_size - for j in T.serial(tile_z): - if _tree_mask( - row=row_, - col=L_kv_start + j, - mask_ptr=mask, - offset=mn_indptr[b_idx], - stride=q_indptr[b_idx + 1] - q_indptr[b_idx], - kv_len=kv_chunk_len[0]): - m_new[i] = T.max(m_new[i], S_smem[row, j]) - d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) - - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): - for j in T.serial(tile_z): - # this is to avoid sync inside condition branch - if row < tile_x: - row_: T.int32 = (LH_start + row) // group_size - if _tree_mask( - row=row_, - col=L_kv_start + j, - mask_ptr=mask, - offset=mn_indptr[b_idx], - stride=q_indptr[b_idx + 1] - q_indptr[b_idx], - kv_len=kv_chunk_len[0]): - S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) - else: - S_smem[row, j] = T.exp2(-5e4 - m_new[i]) - - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - with T.block("update"): - for j in T.serial(tile_z): - d_new[i] += S_smem[row, j] - m_smem[row] = m_new[i] - d_smem[row] = d_new[i] - m_prev_smem[row] = m_prev[i] - T.tvm_storage_sync("shared") - - # Update O - with T.block(): - for li, lj, lk in T.grid(tile_x, tile_y, tile_z): - with T.block("O_gemm"): - i, j, k = T.axis.remap("SSR", [li, lj, lk]) - with T.init(): - O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) - O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") - - # Store O from smem to gmem - for li, lj in T.grid(tile_x, tile_y): - with T.block("O_store"): - i, j = T.axis.remap("SS", [li, lj]) - cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size - cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] - - # Store LSE to gmem - for li in T.grid(tile_x): - with T.block("lse_store"): - i = T.axis.remap("S", [li]) - cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size - cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) - - # move to next tile - tile_id[0] += NUM_BLKS - # fmt: on - # pylint: enable=line-too-long,invalid-name,too-many-branches - sch = tir.Schedule(batch_tree_attn) - - def get_tile_size(x, y, t): - cnt = (x * y) // t - assert (x * y) % t == 0 - tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: - tile_y += 1 - assert tile_y <= cnt - tile_x = cnt // tile_y - return tile_x, tile_y - - def apply_to_qkv_load(sch: tir.Schedule, block): - loop_x, loop_y = sch.get_loops(block)[-2:] - loop = sch.fuse(loop_x, loop_y) - _, ty, tx, vec = sch.split( - loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True - ) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - - def apply_to_so_ewise(sch: tir.Schedule, block, tile): - loop_x, loop_y = sch.get_loops(block)[-2:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False - ): - loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - ko, ki = sch.split(loop_z, factors=[None, r_len]) - if k_major: - sch.reorder(ko, xi, yi, ki) - else: - sch.reorder(ko, ki, xi, yi) - sch.decompose_reduction(block, ty) - - def apply_to_md(sch, block): - loop = sch.get_loops(block)[-1] - _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) - tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) - apply_to_qkv_load(sch, sch.get_block("KV_load")) - - apply_to_md(sch, sch.get_block("lse_store")) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -def _merge_state_inplace( - num_heads, head_dim, v_dtype, target: Target -): # pylint: disable=unused-argument - # pylint: disable=invalid-name - v_dtype_bytes = 2 - VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4) - bdx = head_dim // VEC_SIZE - bdy = num_heads - max_num_threads_per_block = get_max_num_threads_per_block(target) - while bdx * bdy > max_num_threads_per_block and bdy > 1: - bdy //= 2 - gdy = num_heads // bdy - - # undefined vars used - @T.prim_func(check_well_formed=False) - def merge_state_inplace( - v: T.handle, - s: T.handle, - v_other: T.handle, - s_other: T.handle, - ): - T.func_attr({"tir.is_scheduled": 1}) - N = T.int32(is_size_var=True) - H = T.int32(is_size_var=True) - D = T.int32(is_size_var=True) - - V = T.match_buffer(v, (N, H, D), v_dtype) - S = T.match_buffer(s, (N, H), "float32") - V_other = T.match_buffer(v_other, (N, H, D), v_dtype) - S_other = T.match_buffer(s_other, (N, H), "float32") - - for bx in T.thread_binding(N, thread="blockIdx.x"): - for by in T.thread_binding(gdy, thread="blockIdx.y"): - for ty in T.thread_binding(bdy, thread="threadIdx.y"): - for tx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("merge"): - s_val = _var("float32") - s_other_val = _var("float32") - s_max = _var("float32") - scale = _var("float32") - other_scale = _var("float32") - - v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") - v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") - - s_val[0] = S[bx, ty + by * bdy] - s_other_val[0] = S_other[bx, ty + by * bdy] - s_max[0] = T.max(s_val[0], s_other_val[0]) - s_val[0] = T.exp2(s_val[0] - s_max[0]) - s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) - scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) - other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) - - # load v - for vec in T.vectorized(VEC_SIZE): - v_vec[vec] = V[bx, ty + by * bdy, tx * VEC_SIZE + vec] - # load v_other - for vec in T.vectorized(VEC_SIZE): - v_other_vec[vec] = V_other[bx, ty + by * bdy, tx * VEC_SIZE + vec] - - # merge - for vec in T.serial(VEC_SIZE): - v_vec[vec] = ( - v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0] - ) - - # store v - for vec in T.vectorized(VEC_SIZE): - V[bx, ty + by * bdy, tx * VEC_SIZE + vec] = v_vec[vec] - - # store s - S[bx, ty + by * bdy] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] - - # pylint: enable=invalid-name - return merge_state_inplace - - -def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): - tx = 256 if str(target.kind) == "webgpu" else 1024 - - @T.prim_func - def copy_single_page( - pages: T.handle, - src_page_id: T.int64, - tgt_page_id: T.int64, - copy_length: T.int64, - ): - T.func_attr({"tir.is_scheduled": 1}) - num_pages = T.int32() - P = T.match_buffer(pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) - - for b in T.thread_binding( - (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" - ): - for t in T.thread_binding(tx, thread="threadIdx.x"): - with T.block("copy"): - T.where(b * tx + t < copy_length * num_heads * head_dim) - vh = T.axis.spatial( - num_heads, - T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), - ) - vp = T.axis.spatial( - copy_length, - (b * tx + t) % (copy_length * head_dim) // head_dim, - ) - vd = T.axis.spatial( - head_dim, - T.Cast( - "int32", - (b * tx + t) % head_dim, - ), - ) - P[tgt_page_id, 0, vh, vp, vd] = P[src_page_id, 0, vh, vp, vd] - P[tgt_page_id, 1, vh, vp, vd] = P[src_page_id, 1, vh, vp, vd] - - return copy_single_page - - -def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): - tx = 256 if str(target.kind) == "webgpu" else 1024 - - @T.prim_func - def compact_kv_copy( - var_pages: T.handle, - var_copy_length_indptr: T.handle, - var_copy_src_dst_pos: T.handle, - batch_size: T.int32, - ): - T.func_attr({"tir.is_scheduled": 1}) - num_pages = T.int32() - total_copy_length = T.int32() - copy_length_indptr_elem_offset = T.int32() - copy_src_dst_pos_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) - copy_length_indptr = T.match_buffer( - var_copy_length_indptr, - (batch_size + 1,), - "int32", - elem_offset=copy_length_indptr_elem_offset, - ) - copy_src_dst_pos = T.match_buffer( - var_copy_src_dst_pos, - (2, total_copy_length), - "int32", - elem_offset=copy_src_dst_pos_elem_offset, - ) - - for bhd_o in T.thread_binding( - (batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" - ): - for bhd_i in T.thread_binding(tx, thread="threadIdx.x"): - b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) - h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads - d: T.int32 = (bhd_o * tx + bhd_i) % head_dim - if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim: - for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]): - src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] - dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] - pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[ - src_pos // 16, 0, h, src_pos % 16, d - ] - pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[ - src_pos // 16, 1, h, src_pos % 16, d - ] - - return compact_kv_copy - - if __name__ == "__main__": HEAD_DIMS = [64, 128] DTYPES = ["float16", "float32"] From 6ae29610a531cea66e94f8bdcf96f2c5cbdb3bf9 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 9 Aug 2024 09:44:59 -0400 Subject: [PATCH 055/202] [ROCm] Support ROCm 6 (#17256) This PR updates some ROCm modules in order to support ROCm 6. --- cmake/modules/ROCM.cmake | 1 + cmake/utils/FindRCCL.cmake | 2 +- src/runtime/rocm/rocm_device_api.cc | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake index 37fcd716464e..02c4c739934a 100644 --- a/cmake/modules/ROCM.cmake +++ b/cmake/modules/ROCM.cmake @@ -23,6 +23,7 @@ if(ROCM_FOUND) # avoid global retrigger of cmake include_directories(SYSTEM ${ROCM_INCLUDE_DIRS}) add_definitions(-D__HIP_PLATFORM_HCC__=1) + add_definitions(-D__HIP_PLATFORM_AMD__=1) endif(ROCM_FOUND) diff --git a/cmake/utils/FindRCCL.cmake b/cmake/utils/FindRCCL.cmake index 93d8c8744630..95cb555178d0 100644 --- a/cmake/utils/FindRCCL.cmake +++ b/cmake/utils/FindRCCL.cmake @@ -32,7 +32,7 @@ macro(find_rccl use_rccl) find_path(RCCL_INCLUDE_DIR NAMES rccl.h) find_library(RCCL_LIBRARY NAMES rccl) else() - find_path(RCCL_INCLUDE_DIR NAMES rccl.h HINTS ${use_rccl} ${use_rccl}/include) + find_path(RCCL_INCLUDE_DIR NAMES rccl.h HINTS ${use_rccl} ${use_rccl}/include ${use_rccl}/include/rccl) find_library(RCCL_LIBRARY NAMES rccl HINTS ${use_rccl} ${use_rccl}/lib) endif() include(FindPackageHandleStandardArgs) diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index e2a5048ca030..c37e9fada5b2 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -113,7 +113,7 @@ class ROCMDeviceAPI final : public DeviceAPI { case kGcnArch: { hipDeviceProp_t prop; ROCM_CALL(hipGetDeviceProperties(&prop, device.device_id)); - *rv = prop.gcnArch; + *rv = prop.gcnArchName; return; } case kApiVersion: { From e5f85c0e32046b6b1bdc5bd1a2485c645df4e730 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Sat, 10 Aug 2024 21:55:51 +0530 Subject: [PATCH 056/202] [DLIGHT][ADRENO] Fix for opencl adreno matmul schedule (#17259) Fixed the matmul schedule for the case of epilog blocks --- python/tvm/dlight/gpu/matmul.py | 50 +++++++++++---- tests/python/dlight/test_gpu_matmul.py | 89 ++++++++++++++------------ 2 files changed, 85 insertions(+), 54 deletions(-) diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 25cc649b44dd..5fb8e2469d54 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -941,7 +941,7 @@ def get_configs(self, target: Target) -> Config: inner_x=False, ) elif target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("windows" in str(target.host)) + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) ): return Matmul.Config( block_size_x=32, @@ -991,7 +991,10 @@ def is_inner_reduction(block_stmt, iter_infos): end_it = block_stmt.reads[-1].region[-1].min return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == "R" - if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, iter_infos): + if ( + target.kind.name == "opencl" + and (("android" in str(target.host)) or ("adreno" in str(target.attrs))) + ) and not is_inner_reduction(block_stmt, iter_infos): ret = self.sch_outer_reduction(sch, config, main_block, blocks) if ret is not None: return ret @@ -1122,6 +1125,16 @@ def sch_outer_reduction( reduction_block: tir.schedule.BlockRV, blocks: List[tir.schedule.BlockRV], ) -> Optional[tir.Schedule]: + + """Get vectorization factor""" + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + reduction_loops = sch.get_loops(reduction_block) if not len(reduction_loops) == 4: return None @@ -1140,13 +1153,17 @@ def sch_outer_reduction( config.vector_size, config.unroll, ) - - is_dequant_block = len(blocks) > 1 - if is_dequant_block: - compute_block, dequant_block, matmul_block = blocks - sch.compute_inline(compute_block) - else: - (matmul_block,) = blocks + VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize) + dequant_block = None + matmul_block = reduction_block + epilogue_block = None + if blocks[-1] is not matmul_block: + epilogue_block = blocks[-1] + for blk in blocks[:-1]: + if "dequantize" in sch.get(blk).name_hint: + dequant_block = blk + elif blk is not matmul_block: + sch.compute_inline(blk) m = sch.fuse(mb, ms) @@ -1162,12 +1179,13 @@ def sch_outer_reduction( sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv) sch.compute_at(rmat_block, k0) - if is_dequant_block: + if dequant_block is not None: sch.compute_at(dequant_block, k3) sch.reverse_compute_at(wmat_block, mi) sch.set_scope(rmat_block, 0, "shared") sch.set_scope(matmul_block, 0, "local") - if is_dequant_block: + + if dequant_block is not None: sch.set_scope(dequant_block, 0, "local") sch.bind(mo, "blockIdx.y") @@ -1175,7 +1193,7 @@ def sch_outer_reduction( sch.bind(mi, "threadIdx.y") sch.bind(ni, "threadIdx.x") sch.vectorize(sch.get_loops(matmul_block)[-1]) - if is_dequant_block: + if dequant_block is not None: sch.vectorize(sch.get_loops(dequant_block)[-1]) # Co-operative Memory Fetch @@ -1187,7 +1205,7 @@ def sch_outer_reduction( sch.vectorize(wv) # Scale and Quant Cache - if is_dequant_block: + if dequant_block is not None: qb = sch.cache_read(dequant_block, 0, "local") sb = sch.cache_read(dequant_block, 1, "local") sch.compute_at(sb, k1) @@ -1197,5 +1215,11 @@ def sch_outer_reduction( sch.vectorize(sch.get_loops(qb)[-1]) sch.vectorize(sch.get_loops(sb)[-1]) + if epilogue_block is not None: + sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True) + sch.set_scope(wmat_block, 0, "local") + sch.compute_inline(wmat_block) + sch.vectorize(sch.get_loops(epilogue_block)[-1]) + sch.decompose_reduction(matmul_block, k0) return sch diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index 4cef7f1c27c3..dc5276e62a5f 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -685,47 +685,54 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), class TestFusedDequantMatmulAndroid(AndroidBeforeAfter): # fmt: off @T.prim_func - def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle): + def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) seq_len = T.int64() - rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16") - matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") + rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16") + T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") # with T.block("root"): compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") + matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(12288)), "float16") for i0, i1 in T.grid(T.int64(4096), T.int64(12288)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(lv840[v_i0 // T.int64(8), v_i1]) + T.reads(lv452[v_i0 // T.int64(8), v_i1]) T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) + compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(4096), T.int64(12288)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1]) + T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1]) T.writes(dequantize_intermediate_intermediate[v_i0, v_i1]) - dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1] + dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1] for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)): with T.block("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rms_norm260[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2]) + T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2]) T.writes(matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2] + matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2]) + T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) + T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2] @T.prim_func - def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle): + def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle): T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() - rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16") - matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") + rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16") + T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") # with T.block("root"): dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local") - rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared") + rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared") matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", scope="local") - lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") - lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") + lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") + lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"): for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.y"): for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): @@ -743,37 +750,37 @@ def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T for ax0 in range(T.int64(4)): for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax1_1 in T.vectorized(T.int64(8)): - with T.block("rms_norm260_pad"): + with T.block("rms_norm130_pad"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1) - T.reads(rms_norm260[v0, v1, v2]) - T.writes(rms_norm260_pad_shared[v0, v1, v2]) - rms_norm260_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0)) + T.reads(rms_norm130[v0, v1, v2]) + T.writes(rms_norm130_pad_shared[v0, v1, v2]) + rms_norm130_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0)) for k_1 in range(T.int64(8)): for ax0 in T.vectorized(T.int64(8)): - with T.block("lv841_local"): + with T.block("lv453_local"): v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1) v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) - T.reads(lv841[v0, v1]) - T.writes(lv841_local[v0, v1]) - lv841_local[v0, v1] = lv841[v0, v1] + T.reads(lv453[v0, v1]) + T.writes(lv453_local[v0, v1]) + lv453_local[v0, v1] = lv453[v0, v1] for k_2 in range(T.int64(4)): for ax0 in T.vectorized(T.int64(8)): - with T.block("lv840_local"): + with T.block("lv452_local"): v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) - T.reads(lv840[v0, v1]) - T.writes(lv840_local[v0, v1]) - lv840_local[v0, v1] = lv840[v0, v1] + T.reads(lv452[v0, v1]) + T.writes(lv452_local[v0, v1]) + lv452_local[v0, v1] = lv452[v0, v1] for k_3 in range(T.int64(8)): for ax0 in T.vectorized(T.int64(8)): with T.block("dequantize"): v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) - T.reads(lv840_local[v_i0 // T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1]) + T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1]) T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1]) - dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1] + dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1] for i0_i1_fused_2 in range(T.int64(4)): for i2_2 in T.vectorized(T.int64(8)): with T.block("matmul_update"): @@ -781,19 +788,19 @@ def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2) v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) - T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm260_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2]) + T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2]) T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2] - for ax0 in range(T.int64(4)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("matmul_intermediate_pad"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) - v2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) - T.where((i0_i1_fused_0 - (seq_len + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len) - T.reads(matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(matmul_intermediate[v0, v1, v2]) - matmul_intermediate[v0, v1, v2] = matmul_intermediate_pad_local[v0, v1, v2] + matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2] + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(8)): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1) + v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2) + T.where(i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1 < seq_len) + T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2]) + T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) + T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2] # fmt: on From 2d828f5cc29692546317cb0a2e76ba521b1bd080 Mon Sep 17 00:00:00 2001 From: Weiyi Ding <72555042+DDDVE@users.noreply.github.com> Date: Sun, 11 Aug 2024 00:29:26 +0800 Subject: [PATCH 057/202] =?UTF-8?q?[CompileBugfix][contrib]=20meet=20'base?= =?UTF-8?q?64.h:=20No=20such=20file=20or=20directory'=20and=20'=E2=80=98tv?= =?UTF-8?q?m::runtime::vm::AllocatorType=E2=80=99=20has=20not=20been=20dec?= =?UTF-8?q?lared'=20while=20compiling=20(#17265)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/contrib/torch/pt_call_tvm/tvm_class.cc | 2 +- .../tvm_module_wrapper/RuntimeModuleWrapperTVM.cc | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/contrib/torch/pt_call_tvm/tvm_class.cc b/src/contrib/torch/pt_call_tvm/tvm_class.cc index 5e57dc152f11..f5ae95a5a73d 100644 --- a/src/contrib/torch/pt_call_tvm/tvm_class.cc +++ b/src/contrib/torch/pt_call_tvm/tvm_class.cc @@ -167,7 +167,7 @@ class TvmVMModulePack { const auto runtime_create = *tvm::runtime::Registry::Get("runtime._VirtualMachine"); vm_ = runtime_create(exe_); auto init_func = vm_.GetFunction("init", false); - auto alloc_type = static_cast(tvm::runtime::vm::AllocatorType::kPooled); + auto alloc_type = static_cast(tvm::runtime::memory::AllocatorType::kPooled); if (device_type != kDLCPU) { // CPU is required for executing shape functions init_func(static_cast(kDLCPU), 0, alloc_type, device_type, device_id, alloc_type); diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index c77996cf67b6..3e1c7e7c0edf 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -29,7 +29,7 @@ #include #include "../../../runtime/graph_executor/graph_executor_factory.h" -#include "../../support/base64.h" +#include "../../../support/base64.h" #include "runtime_bridge.h" namespace tvm { @@ -209,10 +209,10 @@ inline void b64decode(const std::string b64str, uint8_t* ret) { size_t index = 0; const auto length = b64str.size(); for (size_t i = 0; i < length; i += 4) { - int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]]; - int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]]; - int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]]; - int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]]; + int8_t ch0 = tvm::support::base64::DecodeTable[(int32_t)b64str[i]]; + int8_t ch1 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 1]]; + int8_t ch2 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 2]]; + int8_t ch3 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 3]]; uint8_t st1 = (ch0 << 2) + (ch1 >> 4); ret[index++] = st1; if (b64str[i + 2] != '=') { From bed66d20f1640f814b9f27bcc439f8761e3070cf Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sat, 10 Aug 2024 10:06:17 -0700 Subject: [PATCH 058/202] [Disco] Disable splitting nccl communicator in single-group (#17264) --- src/runtime/disco/nccl/nccl.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index d35fc911c692..a5240aa2b2c5 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -101,8 +101,12 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { ncclUniqueId id; std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES); NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, worker->worker_id)); - NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size, - worker->worker_id % group_size, &ctx->group_comm, NULL)); + if (worker->num_groups == 1) { + ctx->group_comm = ctx->global_comm; + } else { + NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size, + worker->worker_id % group_size, &ctx->group_comm, NULL)); + } } void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) { From b3d01c2295cde9dcd02980bad49fcd9cd3049231 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 11 Aug 2024 13:43:09 -0500 Subject: [PATCH 059/202] [Relax][Bugfix] Preserve dtype in ToMixedPrecision for kNever ops (#17263) Prior to this commit, while an operator with the `MixedPrecisionPolicyKind::kNever` attribute would not be updated from `float32` to `float16`, it would be erroneously updated from `float16` to `float32`. This commit updates `ToMixedPrecision` to preserve the datatype of any arguments used in a `kNever` operation, rather than forcing them to a `float32` datatype. --- src/relax/transform/to_mixed_precision.cc | 69 ++++++++++++------- .../test_transform_to_mixed_precision.py | 34 ++++++++- 2 files changed, 75 insertions(+), 28 deletions(-) diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index c844d5935623..1b660b8fecc5 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -303,11 +303,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } Array RemapArgs(const Array& args) { - Array new_args; - for (const auto& arg : args) { - new_args.push_back(VarReplacer::Replace(arg, var_remap_)); - } - return new_args; + return args.Map([this](Expr arg) { return VarReplacer::Replace(arg, var_remap_); }); } // Util function to rewrite the expr to the given dtype @@ -475,37 +471,60 @@ class ToMixedPrecisionRewriter : public ExprMutator { ReEmitBinding(binding, call_node->args[0]); return; } - DataType to; - ObjectPtr new_call = make_object(*call_node); + + Call new_call = GetRef(call_node); + // We first to remap the args to the current vars according to the var_remap_ - new_call->args = std::move(RemapArgs(call_node->args)); + new_call.CopyOnWrite()->args = RemapArgs(new_call->args); + // Then we rewrite the args according to the policy + std::optional opt_new_dtype = std::nullopt; + if (policy == kAlways) { - to = fp16_; + opt_new_dtype = fp16_; auto attr_map = Op::GetAttrMap("FInferMixedPrecision"); ICHECK(attr_map.count(op)); - auto f = attr_map[op]; - new_call = make_object(*(f(Call(new_call), output_dtype_).get())); + new_call = attr_map[op](new_call, output_dtype_); } else if (policy == kFollow) { - to = AllFP16Castable(new_call->args) ? fp16_ : fp32_; + opt_new_dtype = AllFP16Castable(new_call->args) ? fp16_ : fp32_; } else if (policy == kNever) { - to = fp32_; + // An upstream operation may have changed the datatype of the + // arguments. Because this operation must be provided with + // exactly the same dtype as it previously had, it may require a + // cast back to the original datatype. + + if (!new_call->args.same_as(call_node->args)) { + Array new_typed_args; + for (size_t i = 0; i < call_node->args.size(); i++) { + auto arg = new_call->args[i]; + auto old_ntype = NTypeFrom(call_node->args[i]); + new_typed_args.push_back(RewriteExpr(arg, old_ntype)); + } + new_call.CopyOnWrite()->args = new_typed_args; + } + } else { LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy; } - new_call->args = std::move(RewriteArgs(new_call->args, to)); - new_call->struct_info_ = NullOpt; - Expr new_value = builder_->Normalize(Call(new_call)); - if (policy == kAlways && binding->var->IsInstance()) { - // kAlways: store the tensors to fp16 - // But global vars will be stored to the original dtype anyway (see below) - new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_)); - } - if (!binding->var->IsInstance()) { - // Global var: store the tensors to the original dtype - NType to = NTypeFrom(binding->var); - new_value = RewriteExpr(new_value, to); + + Expr new_value = new_call; + if (opt_new_dtype) { + auto new_dtype = opt_new_dtype.value(); + new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype); + new_call.CopyOnWrite()->struct_info_ = NullOpt; + + new_value = builder_->Normalize(Call(new_call)); + + if (!binding->var->IsInstance()) { + // Non-Dataflow var: store the tensors to the original dtype + new_value = RewriteExpr(new_value, NTypeFrom(binding->var)); + } else if (policy == kAlways && binding->var->IsInstance()) { + // kAlways: store the tensors to fp16 + // But non-dataflow vars will be stored to the original dtype anyway (see above) + new_value = RewriteExpr(new_value, NTypeFrom(new_value, new_dtype)); + } } + ReEmitBinding(binding, builder_->Normalize(new_value)); } diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py index 4ddf47b462ad..ed10fc95c723 100644 --- a/tests/python/relax/test_transform_to_mixed_precision.py +++ b/tests/python/relax/test_transform_to_mixed_precision.py @@ -20,7 +20,7 @@ from tvm import relax import tvm.testing from tvm.relax.transform import ToMixedPrecision -from tvm.script.parser import ir as I, relax as R +from tvm.script.parser import ir as I, relax as R, tir as T def _assert_test(input, expected=None, expected2=None): @@ -614,8 +614,8 @@ def main( x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32") ) -> R.Tensor(None, "float32", ndim=4): with R.dataflow(): - gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, padding=(1, 1)) - gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, axis=1) + gv: R.Tensor((2, 3, 28, 28), "float32") = R.nn.conv2d(x, w, padding=(1, 1)) + gv1: R.Tensor((2, 3, 28, 28), "float32") = R.nn.softmax(x, axis=1) gv2 = R.add(gv, gv1) R.output(gv2) return gv2 @@ -1036,5 +1036,33 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +def test_call_tir_with_float16_args(): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([64], "float16")): + cls = Before + with R.dataflow(): + B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16")) + C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64], "float16")) + R.output(C) + return C + + @T.prim_func + def tir_identity( + Input: T.Buffer(64, "float16"), + Output: T.Buffer(64, "float16"), + ): + for i in range(64): + with T.block("copy"): + vi = T.axis.remap("S", [i]) + Output[vi] = Input[vi] + + Expected = Before + + After = ToMixedPrecision()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() From 02f48828e4b56995be0021c9a98e1705a837e712 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Aug 2024 07:36:17 -0500 Subject: [PATCH 060/202] [FFI] Re-introduce the boxed primitive values (#17257) * Revert "Revert "[FFI][RUNTIME] Introduce runtime boxed types for int/float/bool" (#17252)" This reverts commit 11be83262024fa73a36b744cfd2fc334d5b5e49d. * [FFI] Re-introduce the boxed primitive values Initially introduced in https://github.com/apache/tvm/pull/16183, these changes were reverted in https://github.com/apache/tvm/pull/17252 due to performance degredation in some Relax models. This could occur when a model contained a large number of calls to `"vm.builtin.tuple_getitem"`, which may occur when model weights are provided as a tuple. This PR re-applies the changes from https://github.com/apache/tvm/pull/16183, but with the performance degredation resolved. The root cause was unnecessary type-checking when converting from an untyped `tvm::ArrayNode*` to the typed `tvm::Array`, in the case where `T` is `ObjectRef`. * Correct typo from T to U --- include/tvm/ir/attrs.h | 76 +- include/tvm/ir/expr.h | 130 ++- include/tvm/ir/transform.h | 34 +- include/tvm/meta_schedule/schedule_rule.h | 8 +- include/tvm/relay/attrs/transform.h | 2 +- include/tvm/runtime/c_runtime_api.h | 5 +- .../tvm/runtime/container/boxed_primitive.h | 143 ++++ include/tvm/runtime/container/variant.h | 2 +- include/tvm/runtime/ndarray.h | 2 + include/tvm/runtime/packed_func.h | 756 ++++++++++++++---- include/tvm/target/target.h | 10 +- include/tvm/target/target_kind.h | 4 +- include/tvm/tir/expr.h | 57 ++ include/tvm/tir/function.h | 2 +- include/tvm/tir/schedule/schedule.h | 5 +- python/tvm/_ffi/_ctypes/object.py | 22 + python/tvm/_ffi/_ctypes/packed_func.py | 7 +- python/tvm/_ffi/_ctypes/types.py | 3 + python/tvm/_ffi/_cython/base.pxi | 5 +- python/tvm/_ffi/_cython/object.pxi | 10 + python/tvm/_ffi/_cython/packed_func.pxi | 9 +- python/tvm/_ffi/runtime_ctypes.py | 3 +- python/tvm/driver/tvmc/registry.py | 22 +- python/tvm/ir/attrs.py | 2 +- python/tvm/ir/expr.py | 5 +- python/tvm/meta_schedule/tune_context.py | 3 +- python/tvm/relax/op/statistical.py | 22 +- python/tvm/relax/testing/ast_printer.py | 18 +- python/tvm/relax/training/setup_trainer.py | 4 +- python/tvm/relax/utils.py | 3 + .../relay/backend/contrib/ethosu/legalize.py | 2 +- python/tvm/relay/op/_tensor_grad.py | 3 + python/tvm/relay/op/_transform.py | 8 +- python/tvm/relay/op/contrib/ethosu.py | 4 +- python/tvm/relay/op/transform.py | 25 +- .../transform/fake_quantization_to_integer.py | 5 +- python/tvm/runtime/__init__.py | 4 +- python/tvm/runtime/container.py | 38 + python/tvm/runtime/object_generic.py | 75 +- python/tvm/script/parser/tir/parser.py | 2 + python/tvm/te/hybrid/calls.py | 12 +- python/tvm/te/hybrid/parser.py | 4 +- python/tvm/te/hybrid/utils.py | 28 +- python/tvm/te/operation.py | 1 - python/tvm/te/tensor.py | 11 +- python/tvm/tir/__init__.py | 1 + python/tvm/tir/expr.py | 4 + python/tvm/tir/ir_builder.py | 6 +- python/tvm/tir/op.py | 151 ++-- python/tvm/tir/schedule/trace.py | 15 +- python/tvm/topi/arm_cpu/conv2d_gemm.py | 2 +- python/tvm/topi/cuda/batch_matmul.py | 8 +- rust/tvm-rt/src/module.rs | 5 +- rust/tvm-sys/src/packed_func.rs | 35 +- src/auto_scheduler/compute_dag.cc | 16 +- .../search_policy/sketch_policy_rules.cc | 3 +- src/auto_scheduler/search_policy/utils.h | 12 +- .../msc/core/printer/msc_base_printer.cc | 9 + .../msc/core/printer/prototxt_printer.cc | 4 + src/contrib/msc/core/utils.cc | 4 + src/driver/driver_api.cc | 5 +- src/ir/attrs.cc | 89 +++ src/ir/expr.cc | 17 +- src/ir/transform.cc | 41 +- src/meta_schedule/database/database_utils.cc | 10 +- src/meta_schedule/database/json_database.cc | 4 +- .../mutator/mutate_thread_binding.cc | 2 +- src/meta_schedule/mutator/mutate_tile_size.cc | 6 +- src/meta_schedule/mutator/mutate_unroll.cc | 6 +- .../schedule/cuda/thread_bind.cc | 6 +- .../schedule_rule/cross_thread_reduction.cc | 8 +- .../schedule_rule/multi_level_tiling.cc | 5 +- .../parallel_vectorize_unroll.cc | 6 +- .../schedule_rule/schedule_rule.cc | 12 +- src/meta_schedule/utils.h | 38 +- src/node/boxed_primitive.cc | 134 ++++ src/node/script_printer.cc | 16 +- src/node/structural_equal.cc | 37 +- src/relax/backend/vm/codegen_vm.cc | 2 + src/relax/backend/vm/codegen_vm_tir.cc | 30 +- src/relax/op/tensor/create.cc | 2 +- src/relax/op/tensor/create.h | 2 +- src/relax/op/tensor/manipulate.cc | 6 +- src/relax/op/tensor/manipulate.h | 4 +- .../backend/contrib/cmsisnn/compiler_attrs.cc | 2 +- src/relay/backend/contrib/cmsisnn/target.cc | 2 +- src/relay/backend/contrib/cutlass/target.cc | 18 +- .../backend/contrib/ethosn/ethosn_api.cc | 6 +- src/relay/backend/contrib/ethosu/codegen.cc | 3 +- .../backend/contrib/ethosu/preprocess.cc | 4 +- .../contrib/example_target_hooks/target.cc | 2 +- src/relay/backend/contrib/tensorrt/codegen.cc | 4 +- src/relay/backend/contrib/tensorrt/target.cc | 14 +- src/relay/backend/contrib/uma/targets.cc | 7 +- src/relay/backend/executor.cc | 10 +- src/relay/backend/runtime.cc | 4 +- src/relay/ir/dataflow_matcher.cc | 36 + src/relay/op/make_op.h | 2 +- src/relay/op/tensor/transform.cc | 48 +- .../transforms/combine_parallel_op_batch.cc | 2 +- src/relay/transforms/fold_constant.cc | 2 +- src/relay/transforms/higher_order_gradient.cc | 2 - src/relay/transforms/to_mixed_precision.cc | 4 +- src/runtime/boxed_primitive.cc | 65 ++ src/runtime/crt/common/crt_runtime_api.c | 8 +- src/runtime/disco/bcast_session.cc | 8 +- src/runtime/minrpc/rpc_reference.h | 8 + src/runtime/relax_vm/builtin.cc | 10 +- .../printer/doc_printer/python_doc_printer.cc | 23 +- src/script/printer/ir/misc.cc | 15 + src/script/printer/relax/tir.cc | 6 +- src/support/array.h | 52 +- src/support/ffi_testing.cc | 52 ++ src/target/llvm/codegen_cpu.cc | 29 +- src/target/llvm/llvm_instance.cc | 14 +- src/target/tag.cc | 66 +- src/target/target.cc | 66 +- src/target/target_kind.cc | 137 ++-- src/te/operation/compute_op.cc | 26 +- src/te/operation/create_primfunc.cc | 15 +- src/te/operation/placeholder_op.cc | 12 +- src/te/schedule/schedule_dataflow_rewrite.cc | 7 +- .../analysis/calculate_allocated_memory.cc | 2 +- src/tir/ir/expr.cc | 20 +- src/tir/ir/function.cc | 7 + src/tir/ir/specialize.cc | 2 +- src/tir/ir/stmt.cc | 32 +- src/tir/ir/utils.cc | 68 ++ src/tir/ir/utils.h | 51 ++ src/tir/op/op.cc | 16 +- src/tir/schedule/concrete_schedule.cc | 14 +- src/tir/schedule/concrete_schedule.h | 5 +- src/tir/schedule/instruction_traits.h | 5 + src/tir/schedule/primitive.h | 5 +- src/tir/schedule/primitive/annotate.cc | 3 + src/tir/schedule/primitive/sampling.cc | 36 +- src/tir/schedule/trace.cc | 12 +- src/tir/schedule/traced_schedule.cc | 6 +- src/tir/schedule/traced_schedule.h | 5 +- .../transforms/inline_private_functions.cc | 2 +- src/tir/transforms/ir_utils.h | 1 + src/tir/transforms/lower_tvm_builtin.cc | 2 + src/tir/transforms/make_packed_api.cc | 45 +- tests/cpp/relay/backend/runtime_test.cc | 10 +- tests/cpp/target_test.cc | 56 +- .../test_runtime_packed_func.py | 18 +- .../arith/test_arith_canonical_simplify.py | 23 +- .../arith/test_arith_iter_affine_map.py | 35 +- .../test_arith_narrow_predicate_expression.py | 21 +- .../arith/test_arith_rewrite_simplify.py | 63 +- .../test_arith_solve_linear_equations.py | 15 +- .../test_arith_solve_linear_inequality.py | 11 +- .../codegen/test_target_codegen_cuda.py | 2 +- .../codegen/test_target_codegen_llvm.py | 41 + .../ir/test_container_structural_equal.py | 30 +- tests/python/ir/test_ir_container.py | 15 +- tests/python/ir/test_ir_type.py | 9 +- .../test_distributed_tvmscript_printer.py | 4 +- tests/python/relax/test_ast_printer.py | 2 +- .../relax/test_backend_dispatch_sort_scan.py | 10 +- .../relax/test_tvmscript_printer_relax.py | 6 +- tests/python/relax/test_vm_build.py | 2 +- tests/python/relax/test_vm_codegen_tir.py | 5 +- tests/python/relay/test_dataflow_pattern.py | 3 +- tests/python/relay/test_executor.py | 2 +- tests/python/relay/test_runtime.py | 4 +- tests/python/relay/test_type_infer.py | 65 +- .../python/runtime/test_runtime_container.py | 130 ++- tests/python/te/test_te_schedule_tensorize.py | 20 +- tests/python/te/test_te_tag.py | 10 +- tests/python/tir-base/test_lower_build.py | 2 +- tests/python/tir-base/test_tir_buffer.py | 17 +- tests/python/tir-base/test_tir_index_map.py | 55 +- tests/python/tir-base/test_tir_nodes.py | 27 +- .../test_tir_schedule_sampling.py | 2 +- .../tir-schedule/test_tir_schedule_state.py | 4 +- ...est_tir_transform_compact_buffer_region.py | 71 +- ...tir_transform_instrument_bound_checkers.py | 8 +- .../test_tir_transform_make_packed_api.py | 139 ++++ .../test_tir_transform_storage_rewrite.py | 4 +- .../tvmscript/test_tvmscript_error_report.py | 17 +- .../tvmscript/test_tvmscript_printer_tir.py | 12 +- .../tvmscript/test_tvmscript_roundtrip.py | 31 +- vta/python/vta/transform.py | 13 +- 184 files changed, 3278 insertions(+), 1225 deletions(-) create mode 100644 include/tvm/runtime/container/boxed_primitive.h create mode 100644 src/node/boxed_primitive.cc create mode 100644 src/runtime/boxed_primitive.cc create mode 100644 src/tir/ir/utils.cc create mode 100644 src/tir/ir/utils.h diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 81611b1a535a..d038d5f59a5f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -265,7 +265,16 @@ class DictAttrs : public Attrs { auto it = node->dict.find(attr_key); if (it != node->dict.end()) { - return Downcast>((*it).second); + // For backwards compatibility, return through TVMRetValue. + // This triggers any automatic conversions registered with + // PackedFuncValueConverter. Importantly, this allows use of + // `GetAttr` and `GetAttr` for properties that + // are stored internally as `runtime::Box` and + // `runtime::Box`. + TVMRetValue ret; + ret = (*it).second; + Optional obj = ret; + return obj; } else { return default_value; } @@ -315,6 +324,46 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } +/*! + * \brief Copy the DictAttrs, but overrides attributes with the + * entries from \p attrs. + * + * \param attrs The DictAttrs to update + * + * \param new_attrs Key/values attributes to add to \p attrs. + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); + +/*! + * \brief Copy the DictAttrs, but overrides a single attribute. + * + * \param attrs The DictAttrs to update + * + * \param key The update to insert or update. + * + * \param value The new value of the attribute + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value); + +inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) { + return WithAttr(std::move(attrs), String(key), std::move(value)); +} + +/*! + * \brief Copy the DictAttrs, but without a specific attribute. + * + * \param attrs The DictAttrs to update + * + * \param key The key to remove + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key); + /*! * \brief Copy the function or module, but overrides * the attribute value key with the value. @@ -347,12 +396,8 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } + node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value); + return input; } @@ -371,13 +416,9 @@ inline TFunc WithAttrs(TFunc input, Map attrs) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - if (node->attrs.defined()) { - for (const auto& pair : attrs) { - node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); - } - } else { - node->attrs = DictAttrs(std::move(attrs)); - } + + node->attrs = WithAttrs(std::move(node->attrs), attrs); + return input; } @@ -412,10 +453,9 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - if (input->attrs.defined()) { - TNode* node = input.CopyOnWrite(); - node->attrs.CopyOnWrite()->dict.erase(attr_key); - } + TNode* node = input.CopyOnWrite(); + node->attrs = WithoutAttr(std::move(node->attrs), attr_key); + return input; } diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 9b522389227a..efde52385177 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -770,53 +770,121 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -// common rule for RetValue and ArgValue + +// Automatic conversion into IntImm, Integer, and Bool, when called +// through the FFI. Automatic conversions into PrimExpr are +// registered in "tvm/tir/expr.h", as it includes conversions to the +// TIR-only StringImm. +// +// While the FFI only requires the From() method, these +// implementations also define a TryFrom() method to avoid duplicate +// logic in the PrimExpr conversion. + template <> -struct PackedFuncValueConverter { - static PrimExpr From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return PrimExpr(ObjectPtr(nullptr)); - } - if (val.type_code() == kDLInt) { - int64_t value = val.operator int64_t(); - if (value > std::numeric_limits::max() || value < std::numeric_limits::min()) { - return IntImm(runtime::DataType::Int(64), value); - } - return IntImm(runtime::DataType::Int(32), val.operator int()); - } - if (val.type_code() == kDLFloat) { - return FloatImm(runtime::DataType::Float(32), val.operator double()); +struct PackedFuncValueConverter { + template + static Optional TryFrom(const PODSubclass& val) { + if (auto opt = val.TryAsInt()) { + int64_t value = opt.value(); + auto dtype = + (value > std::numeric_limits::max() || value < std::numeric_limits::min()) + ? DataType::Int(64) + : DataType::Int(32); + return IntImm(dtype, value); + } else if (auto opt = val.TryAsBool()) { + return IntImm(DataType::Int(32), opt.value()); + } else { + return NullOpt; } + } - return PrimExpr::FromObject_(val.AsObjectRef()); + template + static tvm::IntImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } } }; template <> struct PackedFuncValueConverter { - static tvm::Integer From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Integer(ObjectPtr(nullptr)); + template + static tvm::Integer From(const PODSubclass& val) { + if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return Integer(opt.value()); + } else { + return val.template AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - return Integer(val.operator int()); - } - return val.AsObjectRef(); } }; template <> struct PackedFuncValueConverter { - static tvm::Bool From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Bool(ObjectPtr(nullptr)); + template + static Optional TryFrom(const PODSubclass& val) { + if (auto opt = val.TryAsBool()) { + return tvm::Bool(opt.value()); + } else if (auto opt = val.TryAsInt()) { + int value = opt.value(); + ICHECK(value == 0 || value == 1) + << "ValueError: boolean value can only be 0 or 1, but get " << value; + return tvm::Bool(static_cast(value)); + } else { + return NullOpt; + } + } + + template + static tvm::Bool From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - int v = val.operator int(); - ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; - return Bool(static_cast(v)); + } +}; + +template <> +struct PackedFuncValueConverter { + static Optional TryFrom(const TVMPODValue_& val) { + if (auto opt = val.TryAsFloat()) { + return FloatImm(runtime::DataType::Float(32), opt.value()); + } else { + return NullOpt; + } + } + + template + static tvm::FloatImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } + } +}; + +/* \brief Backwards compatibility wrapper for IntImm arguments + * + * In previous versions of TVM, IntImm was the default FFI type for + * integer arguments, instead of runtime::Int. For backwards + * compatibility where the callee has been updated to expected a + * runtime::Int, the caller has not been updated to provide a + * runtime::Int (e.g. relay script parsing), and the auto-unboxing of + * runtime::Int does not apply (e.g. making an `Array`), + * allow the IntImm to be generated. + */ +template <> +struct PackedFuncValueConverter { + template + static runtime::Int From(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return runtime::Int(val.template AsObjectRef()->value); + } else { + return val.template AsObjectRef(); } - return val.AsObjectRef(); } }; diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index adf332525020..5828d98206ad 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -271,7 +271,36 @@ class PassContext : public ObjectRef { using ValueNodeType = typename ValueType::ContainerType; // NOTE: we could further update the function later. uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); - RegisterConfigOption(key, tindex); + auto type_key = runtime::Object::TypeIndex2Key(tindex); + + auto* reflection = ReflectionVTable::Global(); + + auto legalization = [=](ObjectRef obj) -> ObjectRef { + if (obj->IsInstance::ContainerType>()) { + return reflection->CreateObject(type_key, Downcast>(obj)); + } else { + // Backwards compatibility for config options defined prior to + // https://github.com/apache/tvm/pull/16183. This commit + // changed the default FFI conversion of python integers from + // `tvm::IntImm` to `runtime::Int`. + // + // This backwards compatibility fix can be removed when all + // options registered with TVM_REGISTER_PASS_CONFIG_OPTION are + // updated to use `runtime::Int` and `runtime::Bool`. + TVMRetValue ret; + ret = obj; + try { + ValueType legalized = ret; + return legalized; + } catch (Error& err) { + LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key + << ", but received error when converting to this type.\n" + << err.what(); + } + } + }; + + RegisterConfigOption(key, tindex, legalization); return tindex; } @@ -285,7 +314,8 @@ class PassContext : public ObjectRef { // The exit of a pass context scope. TVM_DLL void ExitWithScope(); // Register configuration key value type. - TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index); + TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index, + std::function legalization); // Classes to get the Python `with` like syntax. friend class Internal; diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index d91812fb55cb..90aec05187eb 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef { * \param thread_extents Candidates of thread axis extent (values are required to be positive). * \return The schedule rule created */ - TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); + TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The schedule rule created @@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef { * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma. * \return The schedule rule created */ - TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // - int max_vectorize_extent, // - Array unroll_max_steps, // + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // + Array unroll_max_steps, // bool unroll_explicit); /*! * \brief Auto bind loops around the block to BlockIdx and ThreadIdx diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 249b9cd0e50d..91020fc7443b 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode { }; // struct SqueezeAttrs struct SplitAttrs : public tvm::AttrsNode { - ObjectRef indices_or_sections; + Variant> indices_or_sections; int axis; TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index f1046ef24266..b4c653a0a59e 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -81,6 +81,7 @@ #ifdef __cplusplus extern "C" { #endif +#include #include #include @@ -186,11 +187,12 @@ typedef enum { kTVMBytes = 12U, kTVMNDArrayHandle = 13U, kTVMObjectRValueRefArg = 14U, + kTVMArgBool = 15U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. // Open an issue at the repo if you need a section of code. - kTVMExtBegin = 15U, + kTVMExtBegin = 16U, kTVMNNVMFirst = 16U, kTVMNNVMLast = 20U, // The following section of code is used for non-reserved types. @@ -207,6 +209,7 @@ typedef DLTensor* TVMArrayHandle; */ typedef union { int64_t v_int64; + bool v_bool; double v_float64; void* v_handle; const char* v_str; diff --git a/include/tvm/runtime/container/boxed_primitive.h b/include/tvm/runtime/container/boxed_primitive.h new file mode 100644 index 000000000000..8d01b5dc17b5 --- /dev/null +++ b/include/tvm/runtime/container/boxed_primitive.h @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/boxed_primitive.h + * \brief Runtime container types for primitives stored as ObjectRef. + */ +#ifndef TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ +#define TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +namespace detail { +/* \brief Provide the BoxNode type traits in templated contexts + * + * The Box class is used in many templated contexts, and is easier + * to have templated over the primitive type. + * + * However, much of the TVM type system depends on classes having a + * unique name. For example, the use of `Object::IsInstance` depends + * on `Object::GetOrAllocRuntimeTypeIndex`. Any duplicate names will + * result in duplicate indices, and invalid downcasting. Furthermore, + * the name must be specified in the Python FFI using + * `tvm._ffi.register_object`. This prevents use of + * `typeid(T)::name()` to build a unique name, as the name is not + * required to be human-readable or consistent across compilers. + * + * This utility struct should be specialized over the primitive type + * held by the box, to allow explicit listing of the `_type_key` and + * other similar tratis. + * + * Note: This should only contain traits that are required at runtime, + * and should *not* contain extensions for features that are only + * available at compile-time. For integration with compile-time-only + * functionality (e.g. StructuralHash, StructuralEqual), see + * `BoxNodeCompileTimeTraits` in `src/node/boxed_primitive.cc`. + */ +template +struct BoxNodeRuntimeTraits; + +} // namespace detail + +template +class BoxNode : public Object { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + explicit BoxNode(Prim value) : value(value) {} + + /*! \brief The boxed value */ + Prim value; + + static constexpr const char* _type_key = detail::BoxNodeRuntimeTraits::_type_key; + static constexpr bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(BoxNode, Object); +}; + +template +class Box : public ObjectRef { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + Box(Prim value) : ObjectRef(make_object>(value)) {} // NOLINT(*) + + operator Prim() const { return (*this)->value; } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Box, ObjectRef, BoxNode); +}; + +/*! \brief Boxed version of C++ int64_t + * + * Can be used to store POD integer values as a TVM ObjectRef. Used + * for FFI handling, and for storing POD types inside TVM containers. + */ +using Int = Box; + +/*! \brief Boxed version of C++ double + * + * Can be used to store POD floating-point values as a TVM ObjectRef. + * Used for FFI handling, and for storing POD types inside TVM + * containers. + */ +using Float = Box; + +/*! \brief Boxed version of C++ bool + * + * Can be used to store POD boolean values as a TVM ObjectRef. Used + * for FFI handling, and for storing POD types inside TVM containers. + * + * When passing from Python to C++, TVM PackedFunc conversion follow + * C++ conversion rules, and allow bool->int and int->bool + * conversions. When passing from C++ to Python, the types are + * returned as bool or int. If the C++ function uses ObjectRef to + * hold the object, a Python to C++ to Python round trip will preserve + * the distinction between bool and int. + */ +using Bool = Box; + +namespace detail { +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxInt"; +}; + +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxFloat"; +}; + +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxBool"; +}; +} // namespace detail + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h index 7953ac47c1cf..e8defa4e6fee 100644 --- a/include/tvm/runtime/container/variant.h +++ b/include/tvm/runtime/container/variant.h @@ -82,7 +82,7 @@ class Variant : public ObjectRef { public: /* \brief Helper utility to check if the type is part of the variant */ template - static constexpr bool is_variant = (std::is_same_v || ...); + static constexpr bool is_variant = (std::is_base_of_v || ...); /* \brief Helper utility for SFINAE if the type is part of the variant */ template diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 3eb225fccffe..fef61a753103 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -226,6 +226,8 @@ class NDArray : public ObjectRef { protected: friend class TVMPODValue_; + template + friend class TVMPODValue_CRTP_; friend class TVMRetValue; friend class TVMArgsSetter; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..91e53055b708 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -37,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -429,9 +431,11 @@ inline const char* ArgTypeCode2Str(int type_code); inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) +#define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) \ + "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) + // macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) \ - ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) +#define TVM_CHECK_TYPE_CODE(CODE, T) ICHECK_EQ(CODE, T) << TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -487,6 +491,11 @@ struct ObjectTypeChecker> { if (!ptr->IsInstance()) { return String(ptr->GetTypeKey()); } + + if constexpr (std::is_same_v) { + return NullOpt; + } + const ArrayNode* n = static_cast(ptr); for (size_t i = 0; i < n->size(); i++) { const ObjectRef& p = (*n)[i]; @@ -500,6 +509,8 @@ struct ObjectTypeChecker> { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; if (!ptr->IsInstance()) return false; + if constexpr (std::is_same_v) return true; + const ArrayNode* n = static_cast(ptr); for (const ObjectRef& p : *n) { if (!ObjectTypeChecker::Check(p.get())) { @@ -510,15 +521,27 @@ struct ObjectTypeChecker> { } static std::string TypeName() { return "Array[" + ObjectTypeChecker::TypeName() + "]"; } }; + template struct ObjectTypeChecker> { static Optional CheckAndGetMismatch(const Object* ptr) { if (ptr == nullptr) return NullOpt; if (!ptr->IsInstance()) return String(ptr->GetTypeKey()); + + if constexpr (std::is_same_v && std::is_same_v) { + return NullOpt; + } + const MapNode* n = static_cast(ptr); for (const auto& kv : *n) { - Optional key_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); - Optional value_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); + Optional key_type = NullOpt; + if constexpr (!std::is_same_v) { + key_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); + } + Optional value_type = NullOpt; + if constexpr (!std::is_same_v) { + value_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); + } if (key_type.defined() || value_type.defined()) { std::string key_name = key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker::TypeName(); @@ -532,10 +555,19 @@ struct ObjectTypeChecker> { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; if (!ptr->IsInstance()) return false; + + if constexpr (std::is_same_v && std::is_same_v) { + return true; + } + const MapNode* n = static_cast(ptr); for (const auto& kv : *n) { - if (!ObjectTypeChecker::Check(kv.first.get())) return false; - if (!ObjectTypeChecker::Check(kv.second.get())) return false; + if constexpr (!std::is_same_v) { + if (!ObjectTypeChecker::Check(kv.first.get())) return false; + } + if constexpr (!std::is_same_v) { + if (!ObjectTypeChecker::Check(kv.second.get())) return false; + } } return true; } @@ -545,40 +577,43 @@ struct ObjectTypeChecker> { } }; +template +struct ObjectTypeChecker> { + static Optional CheckAndGetMismatch(const Object* ptr) { + return ObjectTypeChecker::CheckAndGetMismatch(ptr); + } + static bool Check(const Object* ptr) { return ObjectTypeChecker::Check(ptr); } + static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } + static std::string VariantNames() { return ObjectTypeChecker::TypeName(); } +}; + +template +struct ObjectTypeChecker> { + static Optional CheckAndGetMismatch(const Object* ptr) { + auto try_first = ObjectTypeChecker::CheckAndGetMismatch(ptr); + if (!try_first.defined()) { + return try_first; + } + + return ObjectTypeChecker>::CheckAndGetMismatch(ptr); + } + static bool Check(const Object* ptr) { + return ObjectTypeChecker::Check(ptr) || + ObjectTypeChecker>::Check(ptr); + } + static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } + static std::string VariantNames() { + return ObjectTypeChecker::TypeName() + ", " + + ObjectTypeChecker>::VariantNames(); + } +}; + /*! * \brief Internal base class to * handle conversion to POD values. */ class TVMPODValue_ { public: - operator double() const { - // Allow automatic conversion from int to float - // This avoids errors when user pass in int from - // the frontend while the API expects a float. - if (type_code_ == kDLInt) { - return static_cast(value_.v_int64); - } - TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); - return value_.v_float64; - } - operator int64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; - } - operator uint64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; - } - operator int() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - ICHECK_LE(value_.v_int64, std::numeric_limits::max()); - ICHECK_GE(value_.v_int64, std::numeric_limits::min()); - return static_cast(value_.v_int64); - } - operator bool() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64 != 0; - } operator void*() const { if (type_code_ == kTVMNullptr) return nullptr; if (type_code_ == kTVMDLTensorHandle) return value_.v_handle; @@ -628,12 +663,39 @@ class TVMPODValue_ { T* ptr() const { return static_cast(value_.v_handle); } - // ObjectRef handling - template ::value>::type> - inline bool IsObjectRef() const; - template - inline TObjectRef AsObjectRef() const; + + std::optional TryAsBool() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kTVMArgBool) { + return value_.v_bool; + } else { + return std::nullopt; + } + } + + std::optional TryAsInt() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kDLInt) { + return value_.v_int64; + } else { + return std::nullopt; + } + } + + std::optional TryAsFloat() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kDLFloat) { + return value_.v_float64; + } else { + return std::nullopt; + } + } protected: friend class TVMArgsSetter; @@ -648,13 +710,90 @@ class TVMPODValue_ { int type_code_; }; +/*! \brief A utility class that adds methods useful for each POD type + * + * These cannot be provided in the base PODValue_ class, because + * TVMArgValue and TVMRetValue have different semantics for kTVMStr + * and kTVMBytes. + * + * kTVMStr: + * + * For `TVMArgValue`, the active variant is `v_str`, a `const + * char*`. For `TVMRetValue`, the active variant is `v_handle`, + * and should be cast from `void*` to `std::string*`. + * + * kTVMBytes: + * + * The active variant is `v_handle`, a `void*`. For + * `TVMArgValue`, should be cast to `TVMByteArray*`. For + * `TVMRetValue`, should be cast to `std::string*`. + * + * When converting into an `ObjectRef`, a string may be used to build + * a `tvm::runtime::String`. Because TVMArgValue and TVMRetValue use + * different representations for strings, any utility funciton which + * might attempt a conversion to an `ObjectRef` must be performed + * within a context that is aware of the derived class. + */ +template +class TVMPODValue_CRTP_ : public TVMPODValue_ { + public: + using TVMPODValue_::TVMPODValue_; + + // ObjectRef handling + template ::value>::type> + inline bool IsObjectRef() const; + template + inline TObjectRef AsObjectRef() const; + + operator double() const { + // Allow automatic conversion from int to float + // This avoids errors when user pass in int from + // the frontend while the API expects a float. + if (auto opt = TryAsFloat()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsBool()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); + } + } + operator int64_t() const { + if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsBool()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } + } + operator uint64_t() const { return operator int64_t(); } + operator int() const { + int64_t value = operator int64_t(); + ICHECK_LE(value, std::numeric_limits::max()); + ICHECK_GE(value, std::numeric_limits::min()); + return value; + } + operator bool() const { + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } + } +}; + /*! * \brief A single argument value to PackedFunc. * Containing both type_code and TVMValue * * Provides utilities to do type cast into other types. */ -class TVMArgValue : public TVMPODValue_ { +class TVMArgValue : public TVMPODValue_CRTP_ { public: /*! \brief default constructor */ TVMArgValue() {} @@ -663,21 +802,21 @@ class TVMArgValue : public TVMPODValue_ { * \param value of the function * \param type_code The type code. */ - TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} + TVMArgValue(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_::AsObjectRef; - using TVMPODValue_::IsObjectRef; + using TVMPODValue_CRTP_::AsObjectRef; + using TVMPODValue_CRTP_::IsObjectRef; // conversion operator. operator std::string() const { @@ -714,15 +853,15 @@ class TVMArgValue : public TVMPODValue_ { * * \note For internal development purpose only. */ -class TVMMovableArgValue_ : public TVMPODValue_ { +class TVMMovableArgValue_ : public TVMPODValue_CRTP_ { public: - TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} + TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; @@ -804,7 +943,7 @@ class TVMMovableArgValueWithContext_ { * TVMRetValue holds value and will manage the underlying containers * when it stores a complicated data type. */ -class TVMRetValue : public TVMPODValue_ { +class TVMRetValue : public TVMPODValue_CRTP_ { public: /*! \brief default constructor */ TVMRetValue() {} @@ -812,28 +951,28 @@ class TVMRetValue : public TVMPODValue_ { * \brief move constructor from another return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { + TVMRetValue(TVMRetValue&& other) : TVMPODValue_CRTP_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kTVMNullptr; } /*! \brief destructor */ ~TVMRetValue() { this->Clear(); } // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_::AsObjectRef; - using TVMPODValue_::IsObjectRef; + using TVMPODValue_CRTP_::AsObjectRef; + using TVMPODValue_CRTP_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } + TVMRetValue(const TVMRetValue& other) : TVMPODValue_CRTP_() { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMDataType) { @@ -901,8 +1040,8 @@ class TVMRetValue : public TVMPODValue_ { } TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { - this->SwitchToPOD(kDLInt); - value_.v_int64 = value; + this->SwitchToPOD(kTVMArgBool); + value_.v_bool = value; return *this; } TVMRetValue& operator=(std::string value) { @@ -974,7 +1113,8 @@ class TVMRetValue : public TVMPODValue_ { */ static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); + ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle || + type_code == kTVMArgBool); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; @@ -989,9 +1129,9 @@ class TVMRetValue : public TVMPODValue_ { } // ObjectRef handling template ::value>::type> + typename = typename std::enable_if_t>> inline TVMRetValue& operator=(TObjectRef other); - template ::value>::type> + template >> inline operator T() const; private: @@ -1019,9 +1159,11 @@ class TVMRetValue : public TVMPODValue_ { break; } case kTVMObjectHandle: { - // Avoid operator ObjectRef as we already know it is not NDArray/Module - SwitchToObject(kTVMObjectHandle, - GetObjectPtr(static_cast(other.value_.v_handle))); + // We already known it is not NDArray/Module, but + // operator=(ObjectRef) also handles conversions from wrappers + // around primitive types. For NDArray/Module, the duplicate + // checks are removed with if constexpr. + operator=(other.operator ObjectRef()); break; } case kTVMObjectRValueRefArg: { @@ -1265,6 +1407,8 @@ inline const char* ArgTypeCode2Str(int type_code) { switch (type_code) { case kDLInt: return "int"; + case kTVMArgBool: + return "bool"; case kDLUInt: return "uint"; case kDLFloat: @@ -1686,6 +1830,10 @@ class TVMArgsSetter { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } + TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const { + values_[i].v_bool = value; + type_codes_[i] = kTVMArgBool; + } TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); ICHECK_LE(value, static_cast(std::numeric_limits::max())); @@ -1951,38 +2099,110 @@ inline T TVMArgs::At(int i) const { template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { using ContainerType = typename std::remove_reference::type::ContainerType; - if (value.defined()) { - Object* ptr = value.data_.data_; - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + if (!value.defined()) { + type_codes_[i] = kTVMNullptr; + values_[i].v_handle = nullptr; + return; + } + + Object* ptr = value.data_.data_; + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMPackedFuncHandle; - } else if (std::is_rvalue_reference::value) { - values_[i].v_handle = const_cast(&(value.data_.data_)); - type_codes_[i] = kTVMObjectRValueRefArg; - } else { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kTVMObjectHandle; + return; + } + } + + // Like with BoxInt, unwrap any BoxBool instances. See the BoxInt + // explanation for more detail. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_bool = static_cast(ptr)->value; + type_codes_[i] = kTVMArgBool; + return; + } + } + + // If a boxed integer is being returned, always unbox it to the + // primitive type. This must be checked at the PackedFunc level to + // ensure that a boxed primitive argument is round-tripped correctly + // when the boxing is no longer required. + // + // For example, consider a PackedFunc with signature `ObjectRef + // func(Array)`, and returns the first element of that + // array. When passing a Python array `[5, 17.5, "hello"]`, the + // items are converted to `[Box(5), Box(17.5), + // String("hello")]` in order to provide an `Array`. + // + // If we had no additional conversions, the caller would receive the + // return value as a `Box(5)`, which would be unexpected and + // require additional unwrapping. We could perform this check + // inside the PackedFunc, but that would require a large amount of + // duplicated checked, and would require explicit handling of + // `TVMRetValue`. Instead, this conversion is checked in the FFI + // return value, to ensure that boxing/unboxing is applied + // consistently. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_int64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgInt; + return; } + } + + // Like with BoxInt, unwrap any BoxFloat instances. See the BoxInt + // explanation for more detail. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_float64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgFloat; + return; + } + } + + // Final fallback, if the ObjectRef has no special cases that must + // be expressed within the TVMRetValue. + if constexpr (std::is_rvalue_reference_v) { + values_[i].v_handle = const_cast(&(value.data_.data_)); + type_codes_[i] = kTVMObjectRValueRefArg; } else { - type_codes_[i] = kTVMNullptr; - values_[i].v_handle = nullptr; + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kTVMObjectHandle; } } +template template -inline bool TVMPODValue_::IsObjectRef() const { +inline bool TVMPODValue_CRTP_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { @@ -2012,8 +2232,9 @@ inline bool TVMPODValue_::IsObjectRef() const { ObjectTypeChecker::Check(static_cast(value_.v_handle))); } +template template -inline TObjectRef TVMPODValue_::AsObjectRef() const { +inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { static_assert(std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; @@ -2023,8 +2244,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } - // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + + // NOTE: The following code uses "if constexpr" wherever possible to + // minimize the number of runtime checks. + if constexpr (std::is_base_of_v) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); ObjectPtr data = @@ -2033,7 +2256,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2041,7 +2265,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of PackedFunc TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2049,6 +2274,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } + if (type_code_ == kTVMObjectHandle) { // normal object type check. Object* ptr = static_cast(value_.v_handle); @@ -2062,51 +2288,152 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker::TypeName() << ", but got " << checked_type.value(); return TObjectRef(GetObjectPtr(ptr)); - } else if (std::is_base_of::value && - type_code_ == kTVMNDArrayHandle) { - // Casting to a base class that NDArray can sub-class - ObjectPtr data = - NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); - return TObjectRef(data); - } else if (std::is_base_of::value && - type_code_ == kTVMModuleHandle) { - // Casting to a base class that Module can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else if (std::is_base_of::value && - type_code_ == kTVMPackedFuncHandle) { - // Casting to a base class that PackedFunc can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else { - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - return TObjectRef(ObjectPtr(nullptr)); } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMNDArrayHandle) { + // Casting to a base class that NDArray can sub-class + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); + return TObjectRef(data); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMModuleHandle) { + // Casting to a base class that Module can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMPackedFuncHandle) { + // Casting to a base class that PackedFunc can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgInt) { + return Int(value_.v_int64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgFloat) { + return Float(value_.v_float64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgBool) { + return Bool(value_.v_bool); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMStr || type_code_ == kTVMBytes) { + // This step is the reason why `AsObjectRef` cannot be provided + // in the base `TVMPODValue_` class. Because `TVMArgValue` and + // `TVMRetValue` have different implementations of `operator + // std::string`, with different interpretations of `kTVMStr` and + // `kTVMBytes`, we must delegate to those implementations. + // + // This could be done with a pure virtual method in + // `TVMPODValue_`, but that would require a vtable lookup during + // FFI conversions, imposing a runtime overhead. + return String(static_cast(this)->operator std::string()); + } + } + + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); } template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); - if (ptr != nullptr) { - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(NDArray(std::move(other.data_))); - } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(Module(std::move(other.data_))); - } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(PackedFunc(std::move(other.data_))); + + if (ptr) { + // Check for special cases of ObjectRef that have explicit + // representation within the TVMRetValue structure. + // (e.g. Unboxing of `runtime::Int` into a primitive integer + // with type code kTVMArgInt.) The checks below are written to + // handle three distinct cases. + // + // 1. If TObjectRef is a subclass of TSpecialCase, the special + // case applies, and can be handled without a runtime check. + // No runtime checks should be performed. + // + // 2. If TSpecialCase is a subclass of TObjectRef, the special + // case might apply, and requires a runtime check. + // + // 3. If neither TObjectRef nor TSpecialCase is a subclass of + // the other, then the special case does not apply. No + // runtime checks should be performed. + // + // Use of `if constexpr` ensures that the C++ subclass checks + // are applied when compiling TVM, and runtime overhead are only + // present when they may be applicable. + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(NDArray(std::move(other.data_))); + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(Module(std::move(other.data_))); + } } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(PackedFunc(std::move(other.data_))); + } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + bool value = static_cast(ptr)->value; + return operator=(value); + } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + int64_t value = static_cast(ptr)->value; + return operator=(value); + } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + double value = static_cast(ptr)->value; + return operator=(value); + } + } + + // If the object being stored is not one of the special cases, + // it is stored as an ObjectRef. SwitchToObject(kTVMObjectHandle, std::move(other.data_)); + } else { + // No object is present, set to an explicitly null handle. When + // returning to a Python callee, this will be converted to + // `None`. SwitchToPOD(kTVMNullptr); value_.v_handle = nullptr; } + return *this; } @@ -2139,20 +2466,155 @@ inline PackedFunc Module::GetFunction(const String& name, bool query_imports) { // specializations of PackedFuncValueConverter template <> struct PackedFuncValueConverter<::tvm::runtime::String> { - static String From(const TVMArgValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); + template + static String From(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return val.template AsObjectRef(); } else { return tvm::runtime::String(val.operator std::string()); } } +}; - static String From(const TVMRetValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); +template +struct PackedFuncValueConverter> { + static Array From(const TVMArgValue& val) { + auto untyped_array = val.AsObjectRef>(); + + if constexpr (std::is_same_v) { + return untyped_array; + } + + // Attempt to convert each item of the array into the desired + // type. If the items do not require a conversion, no copies are + // made. + return untyped_array.Map([](ObjectRef item) { + // Recursively apply any conversions that have been registered + // with TVM's FFI. + // + // For example, a function that accepts `Array` may + // be called from python with argument `[1,2]`. By the time + // `PackedFuncValueConverter::From` is called, the python list + // has been converted to `Array`, with contents + // converted into `runtime::Int`. Converting the `ObjectRef` + // to `TVMArgValue` unboxes the `runtime::Int` back into a + // primitive with type code `kTVMArgInt`. This primitive can + // then be converted to a PrimExpr using + // `PackedFuncValueConverter::From`. + // + // The use of two conversions, first from python `int` to + // `runtime::Int` and then from `runtime::Int` to `PrimExpr`, + // is a result of the split between `libtvm_runtime.so` and + // `libtvm.so`. The FFI must function correctly in both + // cases, and so conversions applied by default in the Python + // FFI implementation may only produce types that are + // available in both libraries. In the C++ FFI implementation + // (i.e. this file), libtvm.so may apply additional + // conversions that are not present in libtvm_runtime.so. + TVMValue value; + int type_code; + TVMArgsSetter setter(&value, &type_code); + setter(0, item); + TVMArgValue arg(value, type_code); + return PackedFuncValueConverter::From(arg); + }); + } + static Array From(const TVMRetValue& val) { + auto untyped_array = val.AsObjectRef>(); + + if constexpr (std::is_same_v) { + return untyped_array; + } + + return untyped_array.Map([](ObjectRef item) { + TVMRetValue item_val; + item_val = std::move(item); + return PackedFuncValueConverter::From(item_val); + }); + } +}; + +template +struct PackedFuncValueConverter> { + static Map From(const TVMArgValue& val) { + auto untyped_map = val.AsObjectRef>(); + + if constexpr (std::is_same_v && std::is_same_v) { + return Downcast>(untyped_map); + } + + if (ObjectTypeChecker>::Check(untyped_map.get())) { + // Early bail-out for common case where no type conversions are + // required. + return Downcast>(untyped_map); + } + + Map output; + for (const auto& kv : untyped_map) { + T new_key = [&]() { + if constexpr (std::is_same_v) { + return kv.first; + } else { + TVMValue pod_value; + int type_code; + TVMArgsSetter setter(&pod_value, &type_code); + setter(0, kv.first); + TVMArgValue pod_arg(pod_value, type_code); + return PackedFuncValueConverter::From(pod_arg); + } + }(); + U new_value = [&]() { + if constexpr (std::is_same_v) { + return kv.second; + } else { + TVMValue pod_value; + int type_code; + TVMArgsSetter setter(&pod_value, &type_code); + setter(0, kv.second); + TVMArgValue key_arg(pod_value, type_code); + return PackedFuncValueConverter::From(key_arg); + } + }(); + output.Set(new_key, new_value); + } + return output; + } + static Map From(const TVMRetValue& val) { + auto untyped_map = val.AsObjectRef>(); + + if constexpr (std::is_same_v && std::is_same_v) { + return Downcast>(untyped_map); + } + + if (ObjectTypeChecker>::Check(untyped_map.get())) { + // Early bail-out for common case where no type conversions are + // required. + return Downcast>(untyped_map); + } + + Map output; + for (const auto& kv : untyped_map) { + T new_key = [&]() { + if constexpr (std::is_same_v) { + return kv.first; + } else { + TVMRetValue pod; + pod = kv.first; + return PackedFuncValueConverter::From(pod); + } + }(); + U new_value = [&]() { + if constexpr (std::is_same_v) { + return kv.second; + } else { + TVMRetValue pod; + pod = kv.second; + return PackedFuncValueConverter::From(pod); + } + }(); + output.Set(new_key, new_value); } + return output; } }; @@ -2181,7 +2643,7 @@ struct PackedFuncValueConverter> { return opt.value(); } - if (auto opt = TryValueConverter(val)) { + if (auto opt = TryValueConverter(val)) { return opt.value(); } @@ -2192,10 +2654,10 @@ struct PackedFuncValueConverter> { << " but got " << ArgTypeCode2Str(val.type_code()); } - template - static Optional TryAsObjectRef(const TVMPODValue_& val) { - if (val.IsObjectRef()) { - return VType(val.AsObjectRef()); + template + static Optional TryAsObjectRef(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return VType(val.template AsObjectRef()); } else if constexpr (sizeof...(VarRest)) { return TryAsObjectRef(val); } else { @@ -2203,15 +2665,15 @@ struct PackedFuncValueConverter> { } } - template + template static Optional TryValueConverter(const PODSubclass& val) { try { return VType(PackedFuncValueConverter::From(val)); - } catch (const InternalError&) { + } catch (const Error&) { } if constexpr (sizeof...(VarRest)) { - return TryValueConverter(val); + return TryValueConverter(val); } else { return NullOpt; } diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index d47ac94e067e..4c1d1fc1f3d2 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -113,7 +113,15 @@ class TargetNode : public Object { "Can only call GetAttr with ObjectRef types."); auto it = attrs.find(attr_key); if (it != attrs.end()) { - return Downcast>((*it).second); + // For backwards compatibility, return through TVMRetValue. + // This triggers any automatic conversions registered with + // PackedFuncValueConverter. Importantly, this allows use of + // `GetAttr` and `GetAttr` for properties that + // are stored internally as `runtime::Box` and + // `runtime::Box`. + TVMRetValue ret; + ret = (*it).second; + return ret; } else { return default_value; } diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 130aea32f844..6b3b9c31a645 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -445,8 +445,8 @@ constexpr const char* kRelayToTIR = "RelayToTIR"; .add_attr_option("model") \ .add_attr_option>("libs") \ .add_attr_option("host") \ - .add_attr_option("from_device") \ - .add_attr_option("target_device_type") + .add_attr_option("from_device") \ + .add_attr_option("target_device_type") } // namespace tvm diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index d9b65dc8745c..28cb022151d2 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1155,6 +1155,63 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } // namespace tir } // namespace tvm +namespace tvm { +namespace runtime { + +// Automatic conversion into PrimExpr, when called through the FFI. +// Automatic conversions into IntImm, Integer, and Bool are registered +// in "tvm/ir/expr.h", as they are currently in use outside of TIR. + +template <> +struct PackedFuncValueConverter { + template + static Optional TryFrom(const PODSubclass& val) { + auto type_code = val.type_code(); + bool can_convert = type_code == kTVMDataType || type_code == kTVMBytes || + type_code == kTVMStr || val.template IsObjectRef(); + if (can_convert) { + return tvm::tir::StringImm(PackedFuncValueConverter::From(val)); + } else { + return NullOpt; + } + } + + template + static tvm::tir::StringImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } + } +}; + +template <> +struct PackedFuncValueConverter { + // Common rule for RetValue and ArgValue. Templated to ensure + // correct delegation to `operator std::string()` for either + // TVMArgValue or TVMRetValue. + template + static PrimExpr From(const PODSubclass& val) { + if (auto opt = val.TryAsBool()) { + // Check against val.TryAsBool directly, to avoid the + // bounds-checking in PackedFuncValueConverter::TryFrom. + return tvm::Bool(opt.value()); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else { + return PrimExpr::FromObject_(val.template AsObjectRef()); + } + } +}; + +} // namespace runtime +} // namespace tvm + namespace std { template <> struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 274ebd0a6558..1d218c6a7c61 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -264,7 +264,7 @@ class TensorIntrin : public ObjectRef { * B[vi, vj] = A[vi, vj] * \endcode */ -PrimFunc Specialize(PrimFunc func, const Map& param_map); +PrimFunc Specialize(PrimFunc func, const Map>& param_map); /*! * \brief PrimFunc specific attribute names. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9b23973b6f8f..092bd52d5634 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -224,8 +224,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return The random variable sampled from candidates */ - virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) = 0; + virtual ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) = 0; /*! * \brief Sample the factors to perfect tile a specific loop * \param loop_rv The loop to be tiled diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 520e0e42ebbe..8f674eea2ec6 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -60,14 +60,36 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) + + # Handle return values that subclass from both TVM objects and + # python native objects (e.g. runtime.String, a subclass of str). if issubclass(cls, PyNativeObject): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) obj.handle = handle return cls.__from_tvm_object__(cls, obj) + # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) obj.handle = handle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. + # + # The `hasattr` check is done on the object's class, not the + # object itself, to avoid edge cases that can result in invalid + # error messages. If a C++ `LOG(FATAL) << nested_obj;` statement + # requires C++ to Python conversions in order to print + # `nested_obj`, then the `AttributeError` used internally by + # `hasattr` may overwrite the text being collected by + # `LOG(FATAL)`. By checking for the method on the class instead + # of the instance, we avoid throwing the `AttributeError`. + # if hasattr(type(obj), "__into_pynative_object__"): + # return obj.__into_pynative_object__() + return obj diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 5f3aa04914be..6dab1a5db1f4 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -134,6 +134,11 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode + elif isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + values[i].v_bool = arg + type_codes[i] = ArgTypeCode.BOOL elif isinstance(arg, Integral): values[i].v_int64 = arg type_codes[i] = ArgTypeCode.INT @@ -147,7 +152,7 @@ def _make_tvm_args(args, temp_args): values[i].v_int64 = _device_to_int64(arg) type_codes[i] = ArgTypeCode.DLDEVICE elif isinstance(arg, (bytearray, bytes)): - # from_buffer only taeks in bytearray. + # from_buffer only takes in bytearray. if isinstance(arg, bytes): byte_arr = bytearray(arg) temp_args.append(byte_arr) diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index 38d3cd72b55d..45f36eafd78a 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -27,6 +27,7 @@ class TVMValue(ctypes.Union): _fields_ = [ ("v_int64", ctypes.c_int64), + ("v_bool", ctypes.c_bool), ("v_float64", ctypes.c_double), ("v_handle", ctypes.c_void_p), ("v_str", ctypes.c_char_p), @@ -94,6 +95,7 @@ def _device_to_int64(dev): RETURN_SWITCH = { ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.BOOL: lambda x: x.v_bool, ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, @@ -104,6 +106,7 @@ def _device_to_int64(dev): C_TO_PY_ARG_SWITCH = { ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.BOOL: lambda x: x.v_bool, ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 69e1355f7d13..0f7e5fcae6bd 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -16,6 +16,7 @@ # under the License. from ..base import raise_last_ffi_error +from libcpp cimport bool as bool_t from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule @@ -38,7 +39,8 @@ cdef enum TVMArgTypeCode: kTVMBytes = 12 kTVMNDArrayHandle = 13 kTVMObjectRefArg = 14 - kTVMExtBegin = 15 + kTVMArgBool = 15 + kTVMExtBegin = 16 cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct DLDataType: @@ -66,6 +68,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct TVMValue: int64_t v_int64 + bool_t v_bool double v_float64 void* v_handle const char* v_str diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 94a9310d7815..ff38cd3d0ec2 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -60,7 +60,17 @@ cdef inline object make_ret_object(void* chandle): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = chandle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. + # if hasattr(obj, '__into_pynative_object__'): + # return obj.__into_pynative_object__) + return obj + # return obj.__into_pynative_object__() class PyNativeObject: diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 3d1e87bf563d..7977f37d0be5 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -45,7 +45,7 @@ cdef int tvm_callback(TVMValue* args, tcode == kTVMModuleHandle or tcode == kTVMNDArrayHandle or tcode == kTVMObjectRefArg or - tcode > kTVMExtBegin): + tcode >= kTVMExtBegin): CHECK_CALL(TVMCbArgToReturn(&value, &tcode)) if tcode != kTVMDLTensorHandle: @@ -118,6 +118,11 @@ cdef inline int make_arg(object arg, ptr = arg._tvm_handle value[0].v_handle = (ptr) tcode[0] = arg.__class__._tvm_tcode + elif isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + value[0].v_bool = arg + tcode[0] = kTVMArgBool elif isinstance(arg, Integral): value[0].v_int64 = arg tcode[0] = kInt @@ -209,6 +214,8 @@ cdef inline object make_ret(TVMValue value, int tcode): return make_ret_object(value.v_handle) elif tcode == kTVMNullptr: return None + elif tcode == kTVMArgBool: + return value.v_bool elif tcode == kInt: return value.v_int64 elif tcode == kFloat: diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index f148e26f3fcb..03dc18ea6e0b 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -48,7 +48,8 @@ class ArgTypeCode(object): BYTES = 12 NDARRAY_HANDLE = 13 OBJECT_RVALUE_REF_ARG = 14 - EXT_BEGIN = 15 + BOOL = 15 + EXT_BEGIN = 16 class TVMByteArray(ctypes.Structure): diff --git a/python/tvm/driver/tvmc/registry.py b/python/tvm/driver/tvmc/registry.py index c2e74eb1935e..b76202a730a2 100644 --- a/python/tvm/driver/tvmc/registry.py +++ b/python/tvm/driver/tvmc/registry.py @@ -20,11 +20,23 @@ from tvm.driver.tvmc import TVMCException -# We can't tell the type inside an Array but all current options are strings so -# it can default to that. Bool is used alongside Integer but aren't distinguished -# between as both are represented by IntImm -INTERNAL_TO_NATIVE_TYPE = {"runtime.String": str, "IntImm": int, "Array": str} -INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} +# We can't tell the type inside an Array but all current options are +# strings so it can default to that. runtime.BoxBool is used to +# distinguish from runtime.BoxInt. +INTERNAL_TO_NATIVE_TYPE = { + "runtime.String": str, + "runtime.BoxBool": bool, + "runtime.BoxFloat": float, + "runtime.BoxInt": int, + "Array": str, +} +INTERNAL_TO_HELP = { + "runtime.String": " string", + "runtime.BoxBool": " bool", + "runtime.BoxInt": " int", + "runtime.BoxFloat": " float", + "Array": " options", +} def _generate_registry_option_args(parser, registry, name): diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index 6f0a6dd7d155..6afb383c9f04 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -61,7 +61,7 @@ def get_int_tuple(self, key): ------- value: Tuple of int """ - return tuple(x.value for x in self.__getattr__(key)) + return tuple(x if isinstance(x, int) else x.value for x in self.__getattr__(key)) def get_int(self, key): """Get a python int value of a key diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index c70ac2acc71b..263976fa98ff 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -20,7 +20,7 @@ import tvm._ffi -from ..runtime import Object, Scriptable, const, convert +from ..runtime import Object, Scriptable from . import _ffi_api from .base import Node, Span from .type import Type @@ -184,9 +184,6 @@ class Range(Node, Scriptable): def __init__( self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None ) -> None: - if end is None: - end = convert(begin) - begin = const(0, dtype=end.dtype, span=span) self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span) @staticmethod diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 6f76452a57b5..51d9a013d8b3 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -28,6 +28,7 @@ from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule +from tvm.script import tir as T from . import _ffi_api from .logging import Logger, get_logger, get_logging_func @@ -47,7 +48,7 @@ def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: if isinstance(mod, PrimFunc): if not (mod.attrs and "global_symbol" in mod.attrs): mod = mod.with_attr("global_symbol", "main") - mod = mod.with_attr("tir.noalias", True) + mod = mod.with_attr("tir.noalias", T.bool(True)) mod = IRModule({"main": mod}) if not isinstance(mod, IRModule): raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") diff --git a/python/tvm/relax/op/statistical.py b/python/tvm/relax/op/statistical.py index eb44696871eb..502d058ffdf6 100644 --- a/python/tvm/relax/op/statistical.py +++ b/python/tvm/relax/op/statistical.py @@ -195,7 +195,7 @@ def cumprod( data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None, - exclusive: Optional[bool] = None, + exclusive: bool = False, ): """Numpy style cumprod op. Return the cumulative product of the elements along a given axis. @@ -213,9 +213,9 @@ def cumprod( Type of the returned array and of the accumulator in which the elements are computed. If dtype is not specified, it defaults to the dtype of data. - exclusive : Optional[bool] - If true will return exclusive sum in which the first element is not - included. + exclusive : bool + If false (default), all elements are included in the product. If + true, the first element is excluded from the product. Returns ------- @@ -247,6 +247,9 @@ def cumprod( cumprod(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 1, 0, 0, 0, 0] """ + if exclusive is None: + exclusive = False + return _ffi_api.cumprod(data, axis, dtype, exclusive) # type: ignore @@ -254,7 +257,7 @@ def cumsum( data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None, - exclusive: Optional[bool] = None, + exclusive: bool = False, ): """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along a given axis. @@ -272,9 +275,9 @@ def cumsum( Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. - exclusive : Optional[bool] - If true will return exclusive sum in which the first element is not - included. + exclusive : bool + If false (default), all elements are included in the sum. If + true, the first element is excluded from the sum. Returns ------- @@ -306,6 +309,9 @@ def cumsum( cumsum(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 2, 2, 3, 4, 4] """ + if exclusive is None: + exclusive = False + return _ffi_api.cumsum(data, axis, dtype, exclusive) # type: ignore diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 1ed16363b20a..4c670bbe74b2 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -171,11 +171,19 @@ def visit_call_(self, op: relax.Call) -> str: def display_attrs(attr_key): attr_val = op.attrs[attr_key] - # attrs can be strings but also other types; - # we want to wrap strings in quotes - # (__repr__ would work but it uses single quotes) - attr_str = wrap_quotes(attr_val) if isinstance(attr_val, str) else str(attr_val) - return f"{wrap_quotes(attr_key)}: {attr_str}" + + if isinstance(attr_val, str): + # attrs can be strings but also other types; + # we want to wrap strings in quotes + # (__repr__ would work but it uses single quotes) + attr_val = wrap_quotes(attr_val) + elif isinstance(attr_val, tvm.tir.IntImm): + if attr_val.dtype == "bool": + attr_val = bool(attr_val.value) + else: + attr_val = int(attr_val.value) + + return f"{wrap_quotes(attr_key)}: {attr_val}" fields["attrs"] = self.build_list( map(display_attrs, op.attrs.keys()), diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index 71bf8509a63e..aba7ae912c54 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -139,14 +139,14 @@ def _check_well_formed(self, mod: IRModule): # Check function attrs if not self.PARAM_NUM_ATTR_KEY in mod.attrs or not isinstance( - mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm + mod.attrs[self.PARAM_NUM_ATTR_KEY], (IntImm, int) ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " f"{self.PARAM_NUM_ATTR_KEY}" ) if not self.STATE_NUM_ATTR_KEY in mod.attrs or not isinstance( - mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm + mod.attrs[self.STATE_NUM_ATTR_KEY], (IntImm, int) ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 9323bc40da69..e1cab4cbd53b 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -97,6 +97,9 @@ def convert_to_expr(value: Any) -> Expr: if isinstance(value, int): return PrimValue(tir.IntImm("int64", value)) + if isinstance(value, float): + return PrimValue(tir.FloatImm("float64", value)) + tvm_value = convert_to_object(value) # Case 1 if isinstance(tvm_value, Expr): # type: ignore diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 97d7cfa93c8d..199193f75939 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -76,7 +76,7 @@ def get_section_begin_coords(split: tvm.relay.Expr) -> List[int]: # 0 is the beginning of the first section. return [0] + list(indices_or_sections) split_axis_len = input_shape[split_axis].value - section_length = split_axis_len // indices_or_sections.value + section_length = split_axis_len // indices_or_sections return list(range(0, split_axis_len, section_length)) def callback( diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 6b9b311c83b5..dca7b995b22d 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Gradient definitions for Relay operators""" +import tvm from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.error import OpError @@ -383,6 +384,8 @@ def concatenate_grad(orig, grad): axis_dims = [ty.shape[orig.attrs.axis] for ty in t.checked_type.fields] splits, cumsum = [], 0 for dim in axis_dims[:-1]: + if isinstance(dim, tvm.tir.IntImm): + dim = dim.value cumsum += dim splits.append(cumsum) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 93df67ff6b99..8bca72655491 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1057,10 +1057,10 @@ def split_shape_func(attrs, inputs, _): return [ _split_shape_func( inputs[0], - convert(i), - convert(indices_or_sections), - convert(param_is_indices), - convert(axis), + i, + indices_or_sections, + param_is_indices, + axis, ) for i in range(num_out) ] diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index dd04d613079b..c4eff3fcc9e0 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1630,10 +1630,10 @@ def __init__(self, func_body): def convert_indices_or_sections(self, indices_or_sections): # split_v if isinstance(indices_or_sections, tvm.ir.container.Array): - values = [i.value for i in indices_or_sections] + values = [int(i) for i in indices_or_sections] # split else: - values = indices_or_sections.value + values = int(indices_or_sections) return values def is_valid(self): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index ef1cdb3afdd8..dd9c670e2a37 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -18,6 +18,8 @@ # pylint: disable=import-outside-toplevel """Transform operators.""" +from typing import Optional + from ...tir import expr as _expr from ..expr import Constant, Expr, Tuple, TupleWrapper, const from . import _make @@ -855,13 +857,14 @@ def broadcast_to(data, shape): The resulting tensor. """ if isinstance(shape, Constant): - shape = list(shape.data.numpy()) - if isinstance(shape, Expr): + shape = shape.data.numpy() + shape = [_expr.IntImm(str(shape.dtype), int(value)) for value in shape] + elif isinstance(shape, Expr): return _dyn_make.broadcast_to(data, shape) + if isinstance(shape, int): shape = [shape] - if isinstance(shape, (list, tuple)): - shape = list(shape) + return _make.broadcast_to(data, shape) @@ -1938,9 +1941,8 @@ def stft( return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided) -def dft(re_data, im_data, inverse=False): - """ - Computes the discrete Fourier transform of input (calculation along the last axis). +def dft(re_data, im_data, inverse: Optional[bool] = False): + """Computes the discrete Fourier transform of input (calculation along the last axis). This gives frequency components of the signal as they change over time. Parameters @@ -1952,8 +1954,11 @@ def dft(re_data, im_data, inverse=False): N-D tensor, imaginary part of the input signal. If the signal is real, then the values of this tensor are zeros. - inverse : bool + inverse : Optional[bool] + Whether to perform the inverse discrete fourier transform. + Providing None is equivalent to False, and is maintained for + compatibility. Returns ------- @@ -1961,7 +1966,11 @@ def dft(re_data, im_data, inverse=False): The Fourier Transform of the input (Real part). im_output : relay.Expr The Fourier Transform of the input (Imaginary part). + """ + if inverse is None: + inverse = False + return TupleWrapper(_make.dft(re_data, im_data, inverse), 2) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 7ad838895c9f..6eef6ff3ffae 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -364,9 +364,8 @@ def split(expr, type_map): arg = expr.args[0] t = type_map[arg] attrs = {**expr.attrs} - if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm): - num_split = attrs["indices_or_sections"].value - attrs["indices_or_sections"] = num_split + if isinstance(attrs["indices_or_sections"], int): + num_split = attrs["indices_or_sections"] else: num_split = len(attrs["indices_or_sections"]) + 1 return [expr, TupleAffineType([t] * num_split)] diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index f182cd9bfd2f..301f0ef66286 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -27,11 +27,11 @@ from .profiling import Report # function exposures -from .object_generic import convert_to_object, convert, const from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library -from .container import String, ShapeTuple +from .container import String, ShapeTuple # , BoxBool +from .object_generic import convert_to_object, convert, const from .params import ( save_param_dict, load_param_dict, diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 686b4a26c80c..f1a0706a387d 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -172,3 +172,41 @@ def __eq__(self, other): return False return True + + +# @tvm._ffi.register_object("runtime.BoxBool") +# class BoxBool(Object): +# """A boolean wrapped as a tvm Object + +# Parameters +# ---------- +# value: bool + +# The value to hold +# """ + +# def __init__(self, value: bool): +# # Convert to int to avoid an infinite recursion, because +# # BoxBool may be constructed in _make_tvm_args, and calling +# # the packed func `_ffi_api.BoxBool` internally calls +# # `_make_tvm_args`. +# self.__init_handle_by_constructor__(_ffi_api.BoxBool, int(value)) + +# def __into_pynative_object__(self) -> bool: +# return self.value + +# @property +# def value(self) -> bool: +# """Unwrap the boxed value. + +# This is implemented explicitly rather than using the usual +# PackedFunc handling or AttrVisitor mechanics for two reasons. +# First, because the PackedFunc handling would require ambiguous +# representations between `True`/`1` and `False`/`0`. Second, +# because the boxing/unboxing must be available in +# `libtvm_runtime.so`, and AttrVisitor is only available in +# `libtvm.so`. +# """ +# unboxed_bool = _ffi_api.UnBoxBool(self) +# assert unboxed_bool is not None +# return bool(unboxed_bool) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 887c2faaeb2b..20909c53c787 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -38,65 +38,62 @@ def asobject(self): ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PackedFuncBase, PyNativeObject) -def convert_to_object(value, span=None): +def convert_to_object(value): """Convert a Python value to corresponding object type. + Type conversions performed by this function must *only* produce + types that are supported by `libtvm_runtime.so`. This function + must be usable in environments where only TVM runtime support is + present. Automatic conversions to compile-time representations + (e.g. `tir.IntImm` or `relax.PrimValue`) should not be done as + part of this conversion, as these types are not available in + `libtvm_runtime.so`. + Parameters ---------- value : str The value to be inspected. - span : Optional[Span] - The location of this itervar in the source code. - Returns ------- obj : Object The corresponding object value. + """ + if isinstance(value, ObjectTypes): return value - if isinstance(value, bool): - return const(value, "uint1x1", span=span) - if isinstance(value, Number): - return const(value, span=span) - if isinstance(value, string_types): + elif isinstance(value, (bool, int, float)): + return value + elif isinstance(value, string_types): return _ffi_api.String(value) - if isinstance(value, (list, tuple)): - value = [convert_to_object(x) for x in value] + elif isinstance(value, (list, tuple)): + # The call to _ffi_api.Array will convert its own arguments, + # so we don't need to apply any explicit conversions here. return _ffi_api.Array(*value) - if isinstance(value, dict): - vlist = [] - for item in value.items(): - if ( - not isinstance(item[0], ObjectTypes) - and not isinstance(item[0], string_types) - and not isinstance(item[0], Number) - ): - raise ValueError("key of map must already been a container type") - vlist.append(convert_to_object(item[0])) - vlist.append(convert_to_object(item[1])) + elif isinstance(value, dict): + if any(not isinstance(key, (ObjectTypes, string_types, Number)) for key in value): + raise ValueError("key of map must already been a container type") + + vlist = [kv for item in value.items() for kv in item] return _ffi_api.Map(*vlist) - if isinstance(value, ObjectGeneric): + elif isinstance(value, ObjectGeneric): return value.asobject() - if callable(value): + elif callable(value): return convert_to_tvm_func(value) - if value is None: + elif value is None: return None - - raise ValueError(f"don't know how to convert type {type(value)} to object") + else: + raise TypeError(f"don't know how to convert type {type(value)} to object") -def convert(value, span=None): +def convert(value): """Convert value to TVM object or function. Parameters ---------- value : python value - span : Optional[Span] - The location of this statement in the source code. - Returns ------- tvm_val : Object or Function @@ -107,29 +104,29 @@ def convert(value, span=None): This function is redirected to `convert_to_object` as it is widely used in the codebase. We can choose one to keep and discard the other one later. """ - return convert_to_object(value, span=span) + + return convert_to_object(value) def _scalar_type_inference(value): if hasattr(value, "dtype"): - dtype = str(value.dtype) + return str(value.dtype) elif isinstance(value, bool): - dtype = "bool" + return "bool" elif isinstance(value, float): # We intentionally prefer convert the float to float32 since it's more common in DL. if -3.40282347e38 <= value <= 3.40282347e38: - dtype = "float32" + return "float32" else: - dtype = "float64" + return "float64" elif isinstance(value, int): # We intentionally prefer convert the python int to int32 since it's more common in DL. if -2147483648 <= value <= 2147483647: - dtype = "int32" + return "int32" else: - dtype = "int64" + return "int64" else: raise NotImplementedError(f"Cannot automatically inference the type. value={value}") - return dtype def const(value, dtype=None, span=None): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index e545bc3a5e53..3107354ac353 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -536,6 +536,8 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ value = self.eval_expr(node.value) + if value is None: + self.report_error(node, "Expression to be returned must be a PrimExpr") T.evaluate(tvm.tir.ret(value)) diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 462066106a9d..948a0d7665ff 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -96,7 +96,7 @@ def _allocate_tensor(func_id, args): ) shape = args[0] for i in shape: - _internal_assert(isinstance(i, _expr.PrimExpr), "The shape should be an expression") + _internal_assert(isinstance(i, (_expr.PrimExpr, int)), "The shape should be an expression") if n > 1: _internal_assert(isinstance(args[1], str), "The data type should be an str") _internal_assert( @@ -131,9 +131,11 @@ def len(func_id, args): def _cast(func_id, args): _internal_assert( - args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), - "Only one expression can be cast", + args.__len__() == 1, + f"Casting to {func_id} only supports a single argument", ) + # The FFI can handle any conversion of `args[0]` into PrimExpr, if + # required. return _expr.Cast(func_id, args[0]) @@ -145,9 +147,7 @@ def _cast(func_id, args): def ceil_div(func_id, args): _internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!") _internal_assert(args.__len__() == 2, "2 arguments expected for division!") - _internal_assert(isinstance(args[0], _expr.PrimExpr), "Only expressions can div") - _internal_assert(isinstance(args[1], _expr.PrimExpr), "Only expressions can div") - a, b = args[0], args[1] + a, b = args return (a + b - 1) // b diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 846ef818ea54..bd5a060cd01c 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -279,7 +279,7 @@ def visit_Num(self, node): return tvm.runtime.const(node.n, dtype) def visit_NameConstant(self, node): - return tvm.runtime.convert(node.value) + return tvm.tir.const(node.value) def visit_AugAssign(self, node): buf = self.visit(node.target) @@ -376,7 +376,7 @@ def visit_Subscript(self, node): args = [args] arr = self.visit(node.value) - if isinstance(arr, Array): + if isinstance(arr, (Array, list, tuple)): for i in args: if isinstance(i, numbers.Integral): arr = arr[i] diff --git a/python/tvm/te/hybrid/utils.py b/python/tvm/te/hybrid/utils.py index f653b3e83d8b..a515938fa524 100644 --- a/python/tvm/te/hybrid/utils.py +++ b/python/tvm/te/hybrid/utils.py @@ -33,9 +33,9 @@ # pylint: disable=invalid-name -np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) -tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) -halide_imm_types = (_expr.IntImm, _expr.FloatImm) +np_arg_types = (numpy.ndarray, *numeric_types) +tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr, *numeric_types, list, tuple, str) +halide_imm_types = (_expr.IntImm, _expr.FloatImm, *numeric_types) def _internal_assert(cond, err): @@ -91,19 +91,13 @@ def replace(op): def _is_tvm_arg_types(args): """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. If neither is true, raise a value error.""" - if isinstance(args[0], tvm_arg_types): - for elem in args[1:]: - _internal_assert( - isinstance(elem, tvm_arg_types), - f"Expecting a Var, Tensor or ConstExpr instance but {type(elem)} get!", - ) + if all(isinstance(elem, tvm_arg_types) for elem in args): return True - - _internal_assert( - isinstance(args[0], np_arg_types), f"Expect a numpy type but {type(args[0])} get!" - ) - for elem in args[1:]: - _internal_assert( - isinstance(elem, np_arg_types), f"Expect a numpy type but {type(elem)} get!" + elif all(isinstance(elem, np_arg_types) for elem in args): + return False + else: + raise ValueError( + f"Expected arguments to be entirely TVM types, " + f"or entirely numpy types, " + f"but received {[type(elem) for elem in args]}" ) - return False diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index dc2c67849925..64a282dcf755 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -53,7 +53,6 @@ def placeholder(shape, dtype=None, name="placeholder"): tensor: Tensor The created tensor """ - shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape dtype = "float32" if dtype is None else dtype return _ffi_api.Placeholder(shape, dtype, name) diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index d435e821acf3..930667242e29 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -64,16 +64,7 @@ def __call__(self, *indices): f"Need to provide {ndim} index in tensor but {len(indices)} was provided" ) indices = convert_to_object(indices) - args = [] - for x in indices: - if isinstance(x, _expr.PrimExpr): - args.append(x) - elif isinstance(x, _expr.IterVar): - args.append(x.var) - else: - raise ValueError("The indices must be expression") - - return _expr.ProducerLoad(self, args) + return _expr.ProducerLoad(self, indices) def __getitem__(self, indices): return TensorSlice(self, indices) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index bcfbe6575d52..0c8048d24d8b 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -21,6 +21,7 @@ from .buffer import Buffer, decl_buffer, DataProducer from .data_layout import Layout, BijectiveLayout, bijective_layout, layout +from .expr import convert from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index c78bb9e7ecd0..37976394f831 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -41,6 +41,10 @@ from .buffer import Buffer, DataProducer +def convert(expr) -> PrimExpr: + return _ffi_api.convert(expr) + + def div_ambiguity_error() -> RuntimeError: return RuntimeError( "TVM supports multiple types of integer divisions, " diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 50de995a9145..777d46ec7b0d 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,7 +17,7 @@ """Developer API of IR node builder make function.""" import tvm from tvm._ffi.base import string_types -from tvm.runtime import ObjectGeneric, convert, const +from tvm.runtime import ObjectGeneric, const from tvm.ir import container as _container from . import stmt as _stmt @@ -107,7 +107,9 @@ def __getitem__(self, index): def __setitem__(self, index, value): index = self._normalize_index(index) - value = convert(value) + if isinstance(value, (int, bool, float)): + value = tvm.tir.const(value) + value_element = value.dtype.split("x", maxsplit=1)[0] content_element = self._content_type.split("x", maxsplit=1)[0] if value_element != content_element: diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0bc299e403c5..8d9647b60049 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -19,13 +19,14 @@ from typing import Any, Optional, Union import tvm._ffi +from tvm import tir from tvm.ir import Array, Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import const, convert +from tvm.runtime import const from . import _ffi_api from .buffer import Buffer -from .expr import Call, CommReducer, IntImm, PrimExprWithOp, StringImm, Var +from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var def _pack_buffer(buf, span=None): @@ -181,7 +182,7 @@ def call_intrin(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, func_name, convert(args), span) + return Call(dtype, func_name, args, span) def call_pure_extern(dtype, func_name, *args, span=None): @@ -206,9 +207,7 @@ def call_pure_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call( - dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span - ) + return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span) def call_extern(dtype, func_name, *args, span=None): @@ -233,9 +232,7 @@ def call_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span - ) + return Call(dtype, Op.get("tir.call_extern"), [func_name, *args], span=span) def call_llvm_intrin(dtype, name, *args, span=None): @@ -1832,13 +1829,10 @@ def dp4a(vec1, vec2, acc=0): call : PrimExpr The call expression. """ - vec1 = convert(vec1) - vec2 = convert(vec2) - acc = convert(acc) return call_intrin("int32", "tir.dp4a", vec1, vec2, acc) -def ret(val): +def ret(val, span=None): """Create a tir return expression Parameters @@ -1846,14 +1840,16 @@ def ret(val): val : Expr The returned tir expression, whose data type is int, float or void pointer. + span : Optional[Span] + The location of this operator in the source code. + Returns ------- ret : PrimExpr The return expression """ - val = convert(val) - return call_intrin(val.dtype, "tir.ret", val) + return _ffi_api.ret(val, span) def any(*args, span=None): @@ -2038,7 +2034,7 @@ def exp(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp", x) @@ -2055,7 +2051,7 @@ def exp2(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp2", x) @@ -2072,7 +2068,7 @@ def exp10(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp10", x) @@ -2089,7 +2085,7 @@ def erf(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.erf", x) @@ -2106,7 +2102,7 @@ def tanh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.tanh", x) @@ -2123,7 +2119,7 @@ def sigmoid(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sigmoid", x) @@ -2140,7 +2136,7 @@ def log(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log", x) @@ -2157,7 +2153,7 @@ def log2(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log2", x) @@ -2174,7 +2170,7 @@ def log10(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log10", x) @@ -2191,7 +2187,7 @@ def log1p(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log1p", x) @@ -2208,7 +2204,7 @@ def tan(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.tan", x) @@ -2225,7 +2221,7 @@ def cos(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.cos", x) @@ -2242,7 +2238,7 @@ def cosh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.cosh", x) @@ -2259,7 +2255,7 @@ def acos(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.acos", x) @@ -2276,7 +2272,7 @@ def acosh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.acosh", x) @@ -2293,7 +2289,7 @@ def sin(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sin", x) @@ -2310,7 +2306,7 @@ def sinh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sinh", x) @@ -2327,7 +2323,7 @@ def asin(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.asin", x) @@ -2344,7 +2340,7 @@ def asinh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.asinh", x) @@ -2361,7 +2357,7 @@ def atan(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.atan", x) @@ -2378,7 +2374,7 @@ def atanh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.atanh", x) @@ -2398,8 +2394,8 @@ def atan2(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.atan2", x1, x2) @@ -2416,7 +2412,7 @@ def sqrt(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sqrt", x) @@ -2433,7 +2429,7 @@ def rsqrt(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.rsqrt", x) @@ -2679,8 +2675,8 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.nextafter", x1, x2) # type: ignore @@ -2700,8 +2696,8 @@ def hypot(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.hypot", x1, x2) # type: ignore @@ -2721,8 +2717,8 @@ def copysign(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.copysign", x1, x2) # type: ignore @@ -2742,8 +2738,8 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore @@ -2862,7 +2858,7 @@ def power(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore + return _ffi_api._OpPow(x, y, span) # type: ignore def pow(x, y, span=None): @@ -2884,7 +2880,7 @@ def pow(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore + return _ffi_api._OpPow(x, y, span) # type: ignore def popcount(x): @@ -2900,7 +2896,7 @@ def popcount(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.popcount", x) @@ -3032,8 +3028,8 @@ def fmod(x, y): z : PrimExpr The result. """ - x = convert(x) - y = convert(y) + x = tir.convert(x) + y = tir.convert(y) return call_intrin(x.dtype, "tir.fmod", x, y) @@ -3067,7 +3063,7 @@ def if_then_else(cond, t, f, span=None): Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions. """ - return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f), span) # type: ignore + return _ffi_api._OpIfThenElse(cond, t, f, span) # type: ignore def div(a, b, span=None): @@ -3314,34 +3310,23 @@ def _reduce_directly(*args): def _make_reduce(expr, axis, where=None, init=None): code = fcombine.__code__ assert fcombine.__code__.co_argcount == 2 - expr = convert(expr) + expr = tir.convert(expr) if init is not None: - init = convert(init) + init = tir.convert(init) if isinstance(expr, Array): size = len(expr) - larr = [] - rarr = [] + lhs = [] + rhs = [] dtypes = [] for i in range(size): dtype = expr[i].dtype dtypes.append(dtype) lname = code.co_varnames[0] + "_" + str(i) - larr.append(Var(lname, dtype)) + lhs.append(Var(lname, dtype)) rname = code.co_varnames[1] + "_" + str(i) - rarr.append(Var(rname, dtype)) - if init is not None: - init = convert(init) - assert isinstance(init, Array) - assert len(init) == size - for init_i in range(size): - init_i = convert(init_i) - assert isinstance( - init_i, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm) - ) - else: - init = convert([]) - lhs = convert(larr) - rhs = convert(rarr) + rhs.append(Var(rname, dtype)) + if init is None: + init = [] result = fcombine(lhs, rhs) id_elem = fidentity(*dtypes) else: @@ -3352,22 +3337,18 @@ def _make_reduce(expr, axis, where=None, init=None): rvar = Var(code.co_varnames[1], dtype) result = [fcombine(lvar, rvar)] id_elem = [fidentity(dtype)] - lhs = convert([lvar]) - rhs = convert([rvar]) - expr = convert([expr]) + lhs = [lvar] + rhs = [rvar] + expr = [expr] if init is not None: - assert isinstance(init, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm)) - init = convert([init]) - result = convert(result) - id_elem = convert(id_elem) + init = [init] combiner = CommReducer(lhs, rhs, result, id_elem) - axis = convert(axis if isinstance(axis, (list, tuple)) else [axis]) + if not isinstance(axis, (list, tuple, tvm.ir.Array)): + axis = [axis] if where is None: - where = convert(True) + where = tir.convert(True) if init is None: - outputs = tuple( - tvm.tir.Reduce(combiner, expr, axis, where, i, convert([])) for i in range(size) - ) + outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, []) for i in range(size)) else: outputs = tuple( tvm.tir.Reduce(combiner, expr, axis, where, i, init) for i in range(size) diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index cb8d5ce9973e..85377560f1fc 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -39,17 +39,20 @@ def _json_from_tvm(obj): if obj is None: return None - if isinstance(obj, Array): + elif isinstance(obj, (bool, int, float, str)): + return obj + elif isinstance(obj, Array): return [_json_from_tvm(i) for i in obj] - if isinstance(obj, Map): + elif isinstance(obj, Map): return {_json_from_tvm(k): _json_from_tvm(v) for k, v in obj.items()} - if isinstance(obj, String): + elif isinstance(obj, String): return str(obj) - if isinstance(obj, (IntImm, FloatImm)): + elif isinstance(obj, (IntImm, FloatImm)): return obj.value - if isinstance(obj, IndexMap): + elif isinstance(obj, IndexMap): return save_json(obj) - raise TypeError("Not supported type: " + str(type(obj))) + else: + raise TypeError("Not supported type: " + str(type(obj))) @_register_object("tir.Trace") diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index bf6a9c75516f..cc1a28b9dee0 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -468,7 +468,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): C = out.op.input_tensors[0] A = C.op.input_tensors[0] in_type = A.dtype - use_scalable_vectors = out.op.attrs["use_scalable_vectors"].value + use_scalable_vectors = bool(out.op.attrs["use_scalable_vectors"]) tile_M, tile_K = arm_utils.get_tiling_A(False, in_type) tile_N, _ = arm_utils.get_tiling_B_transformed(False, in_type, use_scalable_vectors) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 83b000a4b9bb..0a7acfa50444 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -295,15 +295,11 @@ def batch_matmul_int8( # pad for _dp4a vectorize pad_x = te.compute( (XB, M, nK), - lambda b, i, j: tvm.te.if_then_else( - j >= XK, tvm.runtime.convert(0).astype(x.dtype), x[b, i, j] - ), + lambda b, i, j: tvm.te.if_then_else(j >= XK, tvm.tir.const(0, x.dtype), x[b, i, j]), ) pad_y = te.compute( (YB, N, nK), - lambda b, i, j: tvm.te.if_then_else( - j >= YK, tvm.runtime.convert(0).astype(y.dtype), y[b, i, j] - ), + lambda b, i, j: tvm.te.if_then_else(j >= YK, tvm.tir.const(0, y.dtype), y[b, i, j]), ) out = te.compute( diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index 8d59c2a035a9..b98d9c102baa 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -48,7 +48,7 @@ pub struct ModuleNode { crate::external! { #[name("runtime.RuntimeEnabled")] - fn runtime_enabled(target: CString) -> i32; + fn runtime_enabled(target: CString) -> bool; #[name("runtime.ModuleLoadFromFile")] fn load_from_file(file_name: CString, format: CString) -> Module; @@ -121,8 +121,7 @@ impl Module { /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let target = CString::new(target).unwrap(); - let enabled = runtime_enabled(target).unwrap(); - enabled != 0 + runtime_enabled(target).unwrap() } /// Returns the underlying module handle. diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index a74cbe318e2d..2c1f7db6adb0 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -73,6 +73,7 @@ macro_rules! TVMPODValue { Int(i64), UInt(i64), Float(f64), + Bool(bool), Null, DataType(DLDataType), String(*mut c_char), @@ -95,6 +96,7 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), + TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool), TVMArgTypeCode_kTVMNullptr => Null, TVMArgTypeCode_kTVMDataType => DataType($value.v_type), TVMArgTypeCode_kDLDevice => Device($value.v_device), @@ -117,6 +119,7 @@ macro_rules! TVMPODValue { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), + Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool), Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), @@ -263,6 +266,7 @@ macro_rules! impl_pod_value { impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); impl_pod_value!(Float, f64, [f32, f64]); +impl_pod_value!(Bool, bool, [bool]); impl_pod_value!(DataType, DLDataType, [DLDataType]); impl_pod_value!(Device, DLDevice, [DLDevice]); @@ -380,37 +384,6 @@ impl TryFrom for std::ffi::CString { } } -// Implementations for bool. - -impl<'a> From<&bool> for ArgValue<'a> { - fn from(s: &bool) -> Self { - (*s as i64).into() - } -} - -impl From for RetValue { - fn from(s: bool) -> Self { - (s as i64).into() - } -} - -impl TryFrom for bool { - type Error = ValueDowncastError; - - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> bool, - |RetValue::Int(val)| { !(val == 0) }) - } -} - -impl<'a> TryFrom> for bool { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) }) - } -} - impl From<()> for RetValue { fn from(_: ()) -> Self { RetValue::Null diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index e03d4302c89f..82e439cddbc2 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -554,9 +554,19 @@ class FlopEstimator : public ExprFunctor { if (auto pop = op.as()) { if (pop->attrs.count("FLOP")) { // Use user-provided FLOP - auto pint = pop->attrs["FLOP"].as(); - ICHECK(pint != nullptr); - ret += pint->value; + ObjectRef annotation = pop->attrs["FLOP"]; + auto value = [&]() -> int64_t { + if (auto runtime_int = annotation.as()) { + return runtime_int->value; + } else if (auto int_imm = annotation.as()) { + return int_imm->value; + } else { + LOG(FATAL) << "FLOP annotation must be an integer, " + << "but was an object of type " << annotation->GetTypeKey(); + } + }(); + + ret += value; } else { // Estimate by parsing the compute body double num_element = AxisLengthProd(pop->axis); diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 862e593c9dd3..0bf6da255d2a 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -482,7 +482,8 @@ std::vector> RuleCustomSketch::Apply(const SketchPolicyNod std::vector> ret; for (const auto& item : apply_ret) { CHECK_EQ(item.size(), 2); - auto next = item[1].as(); + auto next = item[1].as(); + ICHECK(next); ret.emplace_back(Downcast(item[0]), next->value); } return ret; diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index 76fb77dd9527..cc6b0ab23756 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -101,7 +101,7 @@ inline int OperationToStage(const te::Operation& op, const State& state) { /*! \brief Get an integer from a tvm str Map. */ inline int GetIntParam(const Map& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pint = attr_dict[key].as(); + auto pint = attr_dict[key].as(); ICHECK(pint != nullptr); return pint->value; } @@ -109,7 +109,7 @@ inline int GetIntParam(const Map& attr_dict, const std::strin /*! \brief Get a double from a tvm str Map. */ inline double GetDoubleParam(const Map& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pdouble = attr_dict[key].as(); + auto pdouble = attr_dict[key].as(); ICHECK(pdouble != nullptr); return pdouble->value; } @@ -120,10 +120,12 @@ inline std::string GetStringParam(const Map& attr_dict, const const auto& target = attr_dict[key]; if (auto pstr = target.as()) { return pstr->value; + } else if (auto pstr = target.as()) { + return pstr->data; + } else { + LOG(FATAL) << "Could not convert object " << target << " of type " << target->GetTypeKey() + << " to string"; } - auto pstr = target.as(); - ICHECK(pstr != nullptr); - return pstr->data; } /*! \brief Get a iterator name set from a tvm str Map. */ diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 289c1b79fd66..708fb56c9851 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -100,8 +100,17 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { const ObjectRef& value = doc->value; if (!value.defined()) { output_ << "\"\""; + } else if (const auto* runtime_int = value.as()) { + output_ << runtime_int->value; } else if (const auto* int_imm = value.as()) { output_ << int_imm->value; + } else if (const auto* runtime_float = value.as()) { + output_.precision(config_.float_precision); + if (std::isinf(runtime_float->value) || std::isnan(runtime_float->value)) { + output_ << '"' << runtime_float->value << '"'; + } else { + output_ << runtime_float->value; + } } else if (const auto* float_imm = value.as()) { output_.precision(config_.float_precision); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index 7e96c657a711..99be910bd70a 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -33,6 +33,10 @@ namespace msc { LiteralDoc PrototxtPrinter::ToLiteralDoc(const ObjectRef& obj) { if (obj.as()) { return LiteralDoc::Str(Downcast(obj), NullOpt); + } else if (auto ptr = obj.as()) { + return LiteralDoc::Int(ptr->value, NullOpt); + } else if (auto ptr = obj.as()) { + return LiteralDoc::Float(ptr->value, NullOpt); } else if (obj.as()) { return LiteralDoc::Int(Downcast(obj)->value, NullOpt); } else if (obj.as()) { diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index f58f95ae53b0..5fcbe924ae1c 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -263,6 +263,10 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { obj_string = ""; } else if (obj.as()) { obj_string = Downcast(obj); + } else if (const auto* n = obj.as()) { + obj_string = std::to_string(n->value); + } else if (const auto* n = obj.as()) { + obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 105ac063e0ea..1e576bc91002 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -171,9 +171,10 @@ Array CreatePassList(bool disable_loop_partition) { // phase passes is of the form // [[phase_number, pass], [phase_number, pass]... ] for (Array phase_pass : add_lower_pass) { - const IntImmNode* phase_num = phase_pass[0].as(); + auto phase_num = phase_pass[0].as(); ICHECK(phase_num) - << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; + << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer, " + << "but instead received " << phase_pass[0] << " with type " << phase_pass[0]->GetTypeKey(); int phase_num_val = phase_num->value; CHECK_GE(phase_num_val, 0); diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index f197ac4416fa..08e7ffc5bf59 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -31,6 +31,91 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } +namespace { + +/* \brief Normalize attributes from runtime types to Relax IR types + * + * While conversion from `tvm::runtime` types to compile-time IR + * types usually occurs as part of FFI conversions, the attributes + * are not converted, as they are stored in a `Map`. While this is required to allow attribute values to + * contain `ObjectRef` instances that are not IR expressions, the + * conversion should still be applied when possible. + * + * \param obj The IR attribute value to be normalized + * + * \return The normalized attribute value + */ +ObjectRef NormalizeAttr(ObjectRef obj) { + if (auto dict_attrs = obj.as()) { + auto new_dict = Downcast>(NormalizeAttr(dict_attrs->dict)); + if (new_dict.same_as(dict_attrs->dict)) { + return obj; + } else { + return DictAttrs(new_dict); + } + } else if (auto runtime_bool = obj.as()) { + return Bool(runtime_bool->value); + } else if (auto runtime_int = obj.as()) { + return Integer(runtime_int->value); + } else if (auto opt_array = obj.as>()) { + return opt_array.value().Map([](const ObjectRef& inner) { return NormalizeAttr(inner); }); + } else if (auto opt_map = obj.as>()) { + auto map = opt_map.value(); + + Map updates; + for (const auto& [key, inner] : map) { + auto new_inner = NormalizeAttr(inner); + if (!new_inner.same_as(inner)) { + updates.Set(key, new_inner); + } + } + for (const auto& [key, new_inner] : updates) { + map.Set(key, new_inner); + } + + return map; + + } else { + return obj; + } +} +} // namespace + +DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { + if (new_attrs.empty()) { + return attrs; + } + + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + + for (const auto& [key, value] : new_attrs) { + attr_dict.Set(key, NormalizeAttr(value)); + } + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + +DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value) { + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + attr_dict.Set(key, NormalizeAttr(value)); + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + +DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + attr_dict.erase(key); + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; @@ -43,11 +128,15 @@ void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_un dict.Set(key, val.operator PrimExpr()); } } + + dict = Downcast>(NormalizeAttr(dict)); } Array DictAttrsNode::ListFieldInfo() const { return {}; } DictAttrs::DictAttrs(Map dict) { + dict = Downcast>(NormalizeAttr(dict)); + ObjectPtr n = make_object(); n->dict = std::move(dict); data_ = std::move(n); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 596805f74b24..ded046eafc5d 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -47,6 +47,12 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { if (auto opt = ref.as()) { return tir::StringImm(opt.value()); } + if (auto opt = ref.as()) { + return Bool(opt.value()); + } + if (auto opt = ref.as()) { + return Integer(opt.value()); + } if (const auto* buffer_region = ref.as()) { Array indices; indices.reserve(buffer_region->region.size()); @@ -155,9 +161,14 @@ Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { TVM_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); -TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Range(args[0], args[1], args[2]); -}); +TVM_REGISTER_GLOBAL("ir.Range") + .set_body_typed([](PrimExpr begin, Optional end, Span span) -> Range { + if (end.defined()) { + return Range(begin, end.value(), span); + } else { + return Range(IntImm(begin->dtype, 0), begin, span); + } + }); TVM_REGISTER_NODE_TYPE(RangeNode); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index dc67822411c5..f0b879acbc03 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -107,43 +107,42 @@ bool PassContext::PassEnabled(const PassInfo& info) const { class PassConfigManager { public: - void Register(std::string key, uint32_t value_type_index) { + void Register(std::string key, uint32_t value_type_index, + std::function legalization) { ICHECK_EQ(key2vtype_.count(key), 0U); ValueTypeInfo info; info.type_index = value_type_index; info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + info.legalization = legalization; key2vtype_[key] = info; } // Trying to validate and legalize a config. void Legalize(Map* config) { std::vector> update; - auto* reflection = ReflectionVTable::Global(); - - for (auto kv : *config) { - auto it = key2vtype_.find(kv.first); + for (auto [key, obj] : *config) { + auto it = key2vtype_.find(key); if (it == key2vtype_.end()) { std::ostringstream os; - os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:"; + os << "AttributeError: Invalid config option \'" << key << "\' candidates are:"; int counter = 0; - for (const auto& kv : key2vtype_) { + for (const auto& [key, obj] : key2vtype_) { os << ' '; if (counter++ != 0) os << ','; - os << kv.first; + os << key; } LOG(FATAL) << os.str(); } const auto& info = it->second; - ICHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None"; - if (kv.second->IsInstance::ContainerType>()) { - ObjectRef converted = - reflection->CreateObject(info.type_key, Downcast>(kv.second)); - update.emplace_back(kv.first, converted); - } else { - if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { - LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " - << info.type_key << " but get " << kv.second->GetTypeKey(); - } + + ICHECK(obj.defined()) << "AttributeError: " << key << " is None"; + + ICHECK(info.legalization) << "AttributeError: " + << "Config option \'" << key + << "\' was defined without a legalization function."; + auto legalized = info.legalization(obj); + if (!legalized.same_as(obj)) { + update.emplace_back(key, legalized); } } for (auto&& kv : update) { @@ -170,13 +169,15 @@ class PassConfigManager { struct ValueTypeInfo { std::string type_key; uint32_t type_index; + std::function legalization; }; std::unordered_map key2vtype_; }; -void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { - PassConfigManager::Global()->Register(key, value_type_index); +void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index, + std::function legalization) { + PassConfigManager::Global()->Register(key, value_type_index, legalization); } Map> PassContext::ListConfigs() { diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index 416753871244..ce025540e496 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -39,8 +39,14 @@ void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { } else { os << int_imm->value; } + } else if (const auto* runtime_bool = json_obj.as()) { + os << (runtime_bool->value ? "true" : "false"); + } else if (const auto* runtime_int = json_obj.as()) { + os << runtime_int->value; } else if (const auto* float_imm = json_obj.as()) { os << std::setprecision(20) << float_imm->value; + } else if (const auto* runtime_float = json_obj.as()) { + os << std::setprecision(20) << runtime_float->value; } else if (const auto* str = json_obj.as()) { os << '"' << support::StrEscape(str->data, str->size) << '"'; } else if (const auto* array = json_obj.as()) { @@ -165,7 +171,7 @@ class JSONTokenizer { std::string to_parse(st, cur_); if (!is_float) { try { - *token = Token{TokenType::kInteger, IntImm(DataType::Int(64), std::stoll(to_parse))}; + *token = Token{TokenType::kInteger, runtime::Int(std::stoll(to_parse))}; } catch (const std::invalid_argument& e) { LOG(WARNING) << "ValueError: Invalid argument to std::stoll: " << to_parse << ". Details: " << e.what() << ". Switching to std::stod now."; @@ -178,7 +184,7 @@ class JSONTokenizer { } if (is_float) { try { - *token = Token{TokenType::kFloat, FloatImm(DataType::Float(64), std::stod(to_parse))}; + *token = Token{TokenType::kFloat, runtime::Float(std::stod(to_parse))}; } catch (const std::invalid_argument& e) { LOG(INFO) << "ValueError: Invalid argument to std::stod: " << to_parse << ". Details: " << e.what(); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 53f680f0a666..63af4a684567 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -192,7 +192,9 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, try { const ArrayNode* arr = json_obj.as(); ICHECK_EQ(arr->size(), 2); - workload = workloads[Downcast(arr->at(0)).IntValue()]; + int64_t workload_index = Downcast(arr->at(0)); + ICHECK(workload_index >= 0 && static_cast(workload_index) < workloads.size()); + workload = workloads[workload_index]; records[task_id] = TuningRecord::FromJSON(arr->at(1), workload); } catch (std::runtime_error& e) { LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index f5d89a85092b..5b3e6d251d56 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -137,7 +137,7 @@ std::vector MutateThreadBindingNode::FindCan ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; - int decision = Downcast(trace->decisions[GetRef(sample_inst)])->value; + int decision = Downcast(trace->decisions[GetRef(sample_inst)]); std::vector probs = support::AsVector(Downcast>(sample_inst->attrs[1])); diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index ea4e81c16f0c..a78b829e34ab 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -129,13 +129,13 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].get())) { ICHECK_EQ(inst->attrs.size(), 2); - std::vector probs = - support::AsVector(Downcast>(inst->attrs[1])); + std::vector probs = support::AsVector( + Downcast>(inst->attrs[1])); if (probs.size() == 1) { // Skip mutating the sampling instructions who have only single candidate. continue; } - const auto* d = TVM_TYPE_AS(decision, IntImmNode); + const auto* d = TVM_TYPE_AS(decision, runtime::Int::ContainerType); instructions.push_back(inst); decisions.push_back(d->value); } diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 7bbf00343af3..36dc57d80e66 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -114,9 +114,9 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, ICHECK_EQ(sample_inst->attrs.size(), 2); candidate->inst = GetRef(sample_inst); candidate->decision = - Downcast(trace->decisions[GetRef(sample_inst)])->value; - candidate->probs = - support::AsVector(Downcast>(sample_inst->attrs[1])); + Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->probs = support::AsVector( + Downcast>(sample_inst->attrs[1])); return true; } diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index b651b1f401cb..110cae96cb53 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -34,11 +34,11 @@ using namespace tvm::tir; std::function MakeFactorSampler(Schedule sch, Array thread_extents) { return [sch = std::move(sch), thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV { - Array extents; + Array extents; extents.reserve(thread_extents.size()); for (const Integer extent : thread_extents) { if (extent->value <= max_extent) { - extents.push_back(extent); + extents.push_back(runtime::Int(extent->value)); } } int n = extents.size(); @@ -48,7 +48,7 @@ std::function MakeFactorSampler(Schedule sch, Array th if (n == 1) { return Integer(extents[0]); } - Array probs(n, FloatImm(DataType::Float(64), 1.0 / n)); + Array probs(n, runtime::Float(1.0 / n)); return sch->SampleCategorical(extents, probs); }; } diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index e8d821636fd3..4a304cefa6bb 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -73,7 +73,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 3. Try block fusion. int n_candidate = static_cast(thread_extents.size()); - Array probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate)); + Array probs(n_candidate, 1.0 / n_candidate); tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { ICHECK(target_block.defined()); @@ -267,7 +267,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { /*! \brief The number of threads per warp */ int warp_size; /*! \brief Candidates of thread axis extent (values are required to be positive). */ - Array thread_extents; + Array thread_extents; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("max_threads_per_block", &max_threads_per_block); @@ -279,8 +279,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); }; -ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { - for (const Integer& extent : thread_extents) { +ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { + for (const auto& extent : thread_extents) { CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive"; } ObjectPtr n = make_object(); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index bcaf4343e256..2979e4229bdd 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -383,9 +383,8 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, if (!valid_vector_lens.empty()) { int n = valid_vector_lens.size(); double prob = 1.0 / n; - tir::ExprRV vector_load_len = - (*sch)->SampleCategorical(support::AsArray(valid_vector_lens), - Array(n, FloatImm(DataType::Float(64), prob))); + tir::ExprRV vector_load_len = (*sch)->SampleCategorical( + support::AsArray(valid_vector_lens), Array(n, prob)); (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); } } diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 045aa85b73ad..8ea2c2d1c6c3 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -68,7 +68,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) { int n = unroll_max_steps.size(); double prob = 1.0 / n; - Array probs(n, FloatImm(DataType::Float(64), prob)); + Array probs(n, runtime::Float(prob)); PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs); if (unroll_explicit) { sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step); @@ -102,7 +102,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { * \brief The options of the maximum number of unroll steps to be done. * Use an empty array to disable unroll. */ - Array unroll_max_steps; + Array unroll_max_steps; /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */ bool unroll_explicit; /*! \brief The number of maximum available jobs in CPU. */ @@ -122,7 +122,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, int max_vectorize_extent, - Array unroll_max_steps, + Array unroll_max_steps, bool unroll_explicit) { ObjectPtr n = make_object(); n->max_jobs_per_core = max_jobs_per_core; diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 3be264332461..83f5d073cb32 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -79,7 +79,7 @@ Array ScheduleRule::DefaultLLVM() { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; @@ -126,7 +126,7 @@ Array ScheduleRule::DefaultX86(const String& type) { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; @@ -158,11 +158,11 @@ Array ScheduleRule::DefaultCUDA() { /*require_ordered=*/false, /*disallow_op=*/Array{}), ScheduleRule::CrossThreadReduction( - /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), + /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/-1, /*max_vectorize_extent=*/-1, - /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, + /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, /*unroll_explicit=*/true), ScheduleRule::AutoBind( /*max_threadblocks=*/256, @@ -297,7 +297,7 @@ Array ScheduleRule::DefaultHexagon() { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/128, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), }; } @@ -410,7 +410,7 @@ Array ScheduleRule::DefaultARM(const String& type) { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/8, /*max_vectorize_extent=*/32, - /*unroll_max_steps=*/Array{0, 8, 32, 256}, + /*unroll_max_steps=*/Array{0, 8, 32, 256}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation()); } diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index ceb0356cbcfe..28c45ea7455d 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -424,13 +424,22 @@ inline Array AsFloatArray(const ObjectRef& obj) { Array results; results.reserve(arr->size()); for (const ObjectRef& elem : *arr) { - if (const auto* int_imm = elem.as()) { - results.push_back(FloatImm(DataType::Float(32), int_imm->value)); - } else if (const auto* float_imm = elem.as()) { - results.push_back(FloatImm(DataType::Float(32), float_imm->value)); - } else { - LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << elem->GetTypeKey(); - } + auto float_value = [&]() -> double { + if (const auto* int_imm = elem.as()) { + return int_imm->value; + } else if (const auto* runtime_int = elem.as()) { + return runtime_int->value; + } else if (const auto* float_imm = elem.as()) { + return float_imm->value; + } else if (const auto* runtime_float = elem.as()) { + return runtime_float->value; + } else { + LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " + << elem->GetTypeKey(); + } + }(); + + results.push_back(FloatImm(DataType::Float(32), float_value)); } return results; } @@ -446,11 +455,16 @@ inline Array AsIntArray(const ObjectRef& obj) { Array results; results.reserve(arr->size()); for (const ObjectRef& elem : *arr) { - if (const auto* int_imm = elem.as()) { - results.push_back(Integer(int_imm->value)); - } else { - LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); - } + auto int_value = [&]() -> int64_t { + if (const auto* int_imm = elem.as()) { + return int_imm->value; + } else if (const auto* runtime_int = elem.as()) { + return runtime_int->value; + } else { + LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); + } + }(); + results.push_back(Integer(int_value)); } return results; } diff --git a/src/node/boxed_primitive.cc b/src/node/boxed_primitive.cc new file mode 100644 index 000000000000..86596fb5ce29 --- /dev/null +++ b/src/node/boxed_primitive.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file node/boxed_primitive.cc + * + * \brief Reflection utilities for runtime-supported classes + * + * The fundamental support for boxing and unboxing of primitives + * during FFI calls is implemented in runtime/boxed_primitive.cc. In + * addition, boxed primitives may be registered with compile-time + * utilities (e.g. reflection, JSON import/export) that can provide + * additional functionality and improved debugging ability. However, + * neither these compile-time utilities nor any registration of + * `Box` into the compile-time utilities should be included as + * part of `libtvm_runtime.so`. + * + * This file contains the registration of the `libtvm_runtime.so` + * class `Box` for utilities that are contained in `libtvm.so`. + */ +#include +#include +#include +#include + +namespace tvm { +namespace runtime_ext { + +using runtime::Box; +using runtime::BoxNode; + +/* \brief Compile-time extension trait for runtime types + * + * Extends the use of boxed primitive during TVM's compilation step. + * + * Most TVM classes define these functions as part of the class + * definition. However, the boxed primitives must be usable at + * runtime, and so the class definition may only refer to types that + * are present in `libtvm_runtime.so`. + */ +template +struct BoxNodeCompileTimeTraits { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const BoxNode* node, SHashReducer hash_reduce) { + hash_reduce(node->value); + } + + static bool SEqualReduce(const BoxNode* lhs, const BoxNode* rhs, + SEqualReducer equal) { + return equal(lhs->value, rhs->value); + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + int64_t value = std::atoll(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + int64_t value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + if (blob == "true") { + return make_object>(true); + } else if (blob == "false") { + return make_object>(false); + } else { + LOG(FATAL) << "Invalid string '" << blob << "' for boolean"; + } + }) + .set_repr_bytes([](const Object* n) -> std::string { + bool value = GetRef(n).as>().value()->value; + if (value) { + return "true"; + } else { + return "false"; + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << (box->value ? "true" : "false") << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + double value = std::atof(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + double value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +} // namespace runtime_ext + +} // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 6e7d82ee4a59..b8918b4ea48c 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -57,7 +57,7 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->binding_names.push_back(Downcast(v)); } if (auto v = config_dict.Get("show_meta")) { - n->show_meta = Downcast(v)->value; + n->show_meta = Downcast(v)->value; } if (auto v = config_dict.Get("ir_prefix")) { n->ir_prefix = Downcast(v); @@ -81,16 +81,16 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->float_dtype = DataType(runtime::String2DLDataType(Downcast(v))); } if (auto v = config_dict.Get("verbose_expr")) { - n->verbose_expr = Downcast(v)->value; + n->verbose_expr = Downcast(v)->value; } if (auto v = config_dict.Get("indent_spaces")) { - n->indent_spaces = Downcast(v)->value; + n->indent_spaces = Downcast(v)->value; } if (auto v = config_dict.Get("print_line_numbers")) { - n->print_line_numbers = Downcast(v)->value; + n->print_line_numbers = Downcast(v)->value; } if (auto v = config_dict.Get("num_context_lines")) { - n->num_context_lines = Downcast(v)->value; + n->num_context_lines = Downcast(v)->value; } if (auto v = config_dict.Get("path_to_underline")) { n->path_to_underline = Downcast>>(v).value_or(Array()); @@ -107,13 +107,13 @@ PrinterConfig::PrinterConfig(Map config_dict) { Downcast>>(v).value_or(Map()); } if (auto v = config_dict.Get("syntax_sugar")) { - n->syntax_sugar = Downcast(v)->value; + n->syntax_sugar = Downcast(v)->value; } if (auto v = config_dict.Get("show_object_address")) { - n->show_object_address = Downcast(v)->value; + n->show_object_address = Downcast(v)->value; } if (auto v = config_dict.Get("show_all_struct_info")) { - n->show_all_struct_info = Downcast(v)->value; + n->show_all_struct_info = Downcast(v)->value; } // Checking prefixes if they are valid Python identifiers. diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 379a75f6109b..614669a412d0 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -65,6 +65,22 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, return fsequal_reduce_[tindex](self, other, equal); } +namespace { +ObjectPath GetAttrPath(const ObjectRef& obj, const void* attr_address, const ObjectPath& path) { + if (obj->IsInstance() || + obj->IsInstance() || + obj->IsInstance()) { + // Special case for containers that contain boxed primitives. The + // "value" attribute containing the boxed value should not be part + // of the reported mismatched path. + return path; + } else { + Optional attr_key = GetAttrKeyByAddress(obj.get(), attr_address); + return path->Attr(attr_key); + } +} +} // namespace + struct SEqualReducer::PathTracingData { ObjectPathPair current_paths; ObjectRef lhs_object; @@ -72,10 +88,9 @@ struct SEqualReducer::PathTracingData { Optional* first_mismatch; ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const { - Optional lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs); - Optional rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs); - return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key), - current_paths->rhs_path->Attr(rhs_attr_key)); + ObjectPath lhs_attr_path = GetAttrPath(lhs_object, &lhs, current_paths->lhs_path); + ObjectPath rhs_attr_path = GetAttrPath(rhs_object, &rhs, current_paths->rhs_path); + return ObjectPathPair(lhs_attr_path, rhs_attr_path); } }; @@ -98,13 +113,12 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { /* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch( const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) { if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) { - Optional lhs_attr_key = - GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address); - Optional rhs_attr_key = - GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address); - *tracing_data->first_mismatch = - ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key), - tracing_data->current_paths->rhs_path->Attr(rhs_attr_key)); + ObjectPath lhs_attr_path = + GetAttrPath(tracing_data->lhs_object, lhs_address, tracing_data->current_paths->lhs_path); + ObjectPath rhs_attr_path = + GetAttrPath(tracing_data->rhs_object, rhs_address, tracing_data->current_paths->rhs_path); + + *tracing_data->first_mismatch = ObjectPathPair(lhs_attr_path, rhs_attr_path); } } @@ -200,7 +214,6 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, } // Slow path: tracing object paths for better error reporting - ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths; if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) { diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 334e6e5c9a62..1c795594629e 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -45,6 +45,7 @@ using namespace relax; using namespace tvm::runtime; using namespace tvm::runtime::relax_vm; +namespace { // Helper function to get the function name of the registered packed function implementation of // relax operator. FCallPacked GetPackedFuncName(const Call& call) { @@ -57,6 +58,7 @@ FCallPacked GetPackedFuncName(const Call& call) { } return {}; } +} // namespace /*! * \brief A class to generate VM executable for Relax functions. diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index dd34bc63bb31..5e6a1c3f8442 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -44,6 +44,21 @@ namespace relax_vm { using vm::VMFuncInfo; +namespace { +// Helper function to get the function name of the registered packed function implementation of +// relax operator. +FCallPacked GetPackedFuncName(const Call& call) { + static auto op_map = Op::GetAttrMap("FCallPacked"); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op]; + } + } + return {}; +} +} // namespace + /*! * \brief A class to generate VMTIR for Relax functions. * @@ -232,7 +247,14 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); if (call->op.as()) { - if (call_node->op == call_builtin_with_ctx_op_) { + // special case generate for the intrinsics whose attribute fields + // cannot be represented by args in the CallNode + FCallPacked name = GetPackedFuncName(call); + if (name.size()) { + // If the operator has a registered packed function implementation, emit call to that packed + // function. + EmitCallPacked(name, VisitArray(call->args), dst_reg); + } else if (call_node->op == call_builtin_with_ctx_op_) { EmitCallBuiltinWithCtx(call, dst_reg); } else if (call_node->op == alloc_storage_op_) { EmitAllocStorage(call, dst_reg); @@ -260,10 +282,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); - // turn ndarray cond value into scalar. - cond_value = tir::Cast(DataType::Bool(), - tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), - {tir::StringImm("vm.builtin.read_if_cond"), cond_value})); + cond_value = tir::Call(DataType::Bool(), tir::builtin::tvm_call_packed(), + {tir::StringImm("vm.builtin.read_if_cond"), cond_value}); tir::Stmt true_branch = WithNewScope([&]() { PrimExpr true_value = this->VisitExpr(op->true_branch).value(); diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index fd6fea6e703c..7aca1470aee4 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -36,7 +36,7 @@ namespace relax { TVM_REGISTER_NODE_TYPE(InitAttrs); /* relax.full */ -Expr full(ObjectRef shape, Expr fill_value, DataType dtype) { +Expr full(Variant> shape, Expr fill_value, DataType dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { shape_in_expr = GetRef(expr); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 989eaa12fdbf..6e7c8255238a 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -39,7 +39,7 @@ namespace relax { * If dtype is not given, it will by default use the dtype of fill_value. * \return The result tensor. */ -Expr full(ObjectRef shape, Expr fill_value, DataType dtype); +Expr full(Variant> shape, Expr fill_value, DataType dtype); /*! * \brief Construct a tensor such that diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 07c90756bf90..2b1c6eafb652 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -654,7 +654,7 @@ TVM_REGISTER_OP("relax.permute_dims") .set_attr("FPurity", Bool(true)); /* relax.reshape */ -Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { +Expr ConvertNewShapeToExpr(const Expr& data, const Variant>& shape) { const ArrayNode* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as()) { @@ -747,7 +747,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { return ShapeExpr(array_ref); } -Expr reshape(Expr x, ObjectRef shape) { +Expr reshape(Expr x, Variant> shape) { Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); static const Op& op = Op::Get("relax.reshape"); return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); @@ -812,7 +812,7 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); -Expr split(Expr x, ObjectRef indices_or_sections, int axis) { +Expr split(Expr x, Variant> indices_or_sections, int axis) { ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 32aa10776894..68622f1359e0 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -90,7 +90,7 @@ Expr permute_dims(Expr x, Optional> axes); * It is required to be either an Array of PrimExpr, or a Shape in Relax * \return The reshaped result. */ -Expr reshape(Expr x, ObjectRef shape); +Expr reshape(Expr x, Variant> shape); /*! * \brief Split input tensor along axis by sections or indices. @@ -105,7 +105,7 @@ Expr reshape(Expr x, ObjectRef shape); * \param axis The axis over which to split. * \return The computed result. */ -Expr split(Expr x, ObjectRef indices_or_sections, int axis); +Expr split(Expr x, Variant> indices_or_sections, int axis); /*! * \brief Squeeze axes in the array. diff --git a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc index 61b6c9ce897f..345e2d0e60da 100644 --- a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc +++ b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc @@ -40,7 +40,7 @@ Target CreateTarget(const tvm::transform::PassContext& ctx) { String mcpu = cfg.value()->mcpu; Array mattr = {cfg.value()->mattr}; - Bool debug_last_error = cfg.value()->debug_last_error; + runtime::Bool debug_last_error = cfg.value()->debug_last_error->value; Target cmsis_nn_target(TargetJSON{ {"kind", String("cmsis-nn")}, diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index 10125bf814ad..00581a089a4a 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -37,7 +37,7 @@ using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc>("mattr") .add_attr_option("mcpu") - .add_attr_option("debug_last_error") + .add_attr_option("debug_last_error") .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); diff --git a/src/relay/backend/contrib/cutlass/target.cc b/src/relay/backend/contrib/cutlass/target.cc index 50c8b84a9069..ea040f6ff56a 100644 --- a/src/relay/backend/contrib/cutlass/target.cc +++ b/src/relay/backend/contrib/cutlass/target.cc @@ -39,32 +39,32 @@ namespace cutlass { * src/relay/backend/contrib/cutlass/codegen.cc */ TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)) .set_attr("RelayToTIR", CompileForCutlass()) // An integer specifying the compute capability. For example, 75 for Turing and // 80 or 86 for Ampere. - .add_attr_option("sm", Integer(80)) + .add_attr_option("sm", runtime::Int(80)) // Whether to use slower but very accurate (compared to tf32) 3xtf32 mode for // fp32 inputs on tensorcore. - .add_attr_option("use_3xtf32", Bool(true)) + .add_attr_option("use_3xtf32", runtime::Bool(true)) // Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in // parallel across split-K blocks, and a separate global reduction kernel is launched to // accumulate partial reductions. The profiler will pick the best split-k factor from the // given candidate list. Note that the larger split-K factor requires a larger workspace. // Currently, parallel split-k has been tested only for wgrad. For GEMM and other conv2d // kinds, split_k_slices is ignored. - .add_attr_option>("split_k_slices", Array({1})) + .add_attr_option>("split_k_slices", Array{runtime::Int(1)}) // When True, profile all kernel variants with smaller alignments than the largest possible. - .add_attr_option("profile_all_alignments", Bool(false)) + .add_attr_option("profile_all_alignments", runtime::Bool(false)) // Whether to profile all candidate kernels, or stop profiling after the first applicable kernel // is found. - .add_attr_option("find_first_valid", Bool(false)) + .add_attr_option("find_first_valid", runtime::Bool(false)) // Whether to compile profiler executables for different kernels in parallel. - .add_attr_option("use_multiprocessing", Bool(false)) + .add_attr_option("use_multiprocessing", runtime::Bool(false)) // Number of threads to use during compilation, or -1 to use number of cpus. - .add_attr_option("threads", Integer(-1)) + .add_attr_option("threads", runtime::Int(-1)) // Whether to replace sigmoid with tanh. - .add_attr_option("use_fast_math", Bool(false)) + .add_attr_option("use_fast_math", runtime::Bool(false)) // A temporary directory where intermediate compiled artifacts will be stored. .add_attr_option("tmp_dir", String("./tmp")); diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.cc b/src/relay/backend/contrib/ethosn/ethosn_api.cc index a3f3e6e1eb6e..0f539d96e919 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api.cc +++ b/src/relay/backend/contrib/ethosn/ethosn_api.cc @@ -687,14 +687,14 @@ EthosnError EthosnAPI::Split(const Expr& expr, SplitParams* params) { sl::TensorInfo(input_tensor_shape, input_data_type, params->input_info.m_DataFormat, params->input_info.m_QuantizationInfo); params->split_info.m_Axis = attrs->axis; - if (attrs->indices_or_sections->IsInstance()) { - auto sections = Downcast(attrs->indices_or_sections)->value; + if (const auto* sections_ptr = attrs->indices_or_sections.as()) { + auto sections = sections_ptr->value; int size = input_tensor_shape[attrs->axis] / sections; for (int i = 0; i < sections; i++) { params->split_info.m_Sizes.push_back(size); } } else { - auto indices = Downcast>(attrs->indices_or_sections); + auto indices = Downcast>(attrs->indices_or_sections); int last_index = 0; for (const auto& i : indices) { params->split_info.m_Sizes.push_back(i->value - last_index); diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index 54d0595c4634..300372838416 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -307,8 +307,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { Array compile_artifacts; for (const auto& kv : mod->functions) { const tir::PrimFunc& prim_func = Downcast(kv.second); - Optional> params = - prim_func->GetAttr>("ethos-u.constants"); + auto params = prim_func->GetAttr>("ethos-u.constants"); ICHECK(params) << "microNPU params should be present"; auto primfunc_to_artifact_pf = tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact"); diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index 23a873b2d392..d87447f863e2 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -97,7 +97,7 @@ class ExternalFuncIOHandler : public ExprRewriter { Expr CreateSplitReshapedTensors(const Expr& input, const Array& original_args) { Array> shapes; Array flatten_tensor_sizes; - Array split_indices; + Array split_indices; Array rets; int total_size = 0; @@ -132,7 +132,7 @@ class ExternalFuncIOHandler : public ExprRewriter { if (func->params.size() > 1) { Array> shapes; Array flatten_tensor_sizes; - Array split_indices; + Array split_indices; auto func_name = gv->name_hint; int total_size = 0; diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index b45987f6be33..de9c81a2706e 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -38,6 +38,6 @@ TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) .set_attr(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR()) .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime) - .add_attr_option("example_attribute", Integer(0)); + .add_attr_option("example_attribute", Integer(0)); } // namespace tvm diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index f4babad50a3e..1dd5e3a4d772 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -177,12 +177,12 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { std::vector indices_or_sections; std::vector mode; std::vector axis = {std::to_string(split_attr->axis)}; - if (const auto* sections = split_attr->indices_or_sections.as()) { + if (const auto* sections = split_attr->indices_or_sections.as()) { mode.emplace_back("sections"); indices_or_sections.emplace_back(std::to_string(sections->value)); } else { mode.emplace_back("indices"); - auto indices = Downcast>(split_attr->indices_or_sections); + auto indices = Downcast>(split_attr->indices_or_sections); for (const auto& i : indices) { indices_or_sections.emplace_back(std::to_string(i->value)); } diff --git a/src/relay/backend/contrib/tensorrt/target.cc b/src/relay/backend/contrib/tensorrt/target.cc index 0277787a8c12..a62dc25e329c 100644 --- a/src/relay/backend/contrib/tensorrt/target.cc +++ b/src/relay/backend/contrib/tensorrt/target.cc @@ -38,30 +38,30 @@ namespace tensorrt { * - Runtime: src/runtime/contrib/tensorrt/... */ TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)) .set_attr("RelayToTIR", CompileForTensorRT()) // A array of three integers given the major, minor, and patch numbers for the supported // TensorRT compiler version. If empty will be auto-detected from linked library. Default empty. - .add_attr_option>("tensorrt_version", Array()) + .add_attr_option>("tensorrt_version", Array()) // If true, the first tensor dimension for most operators is allowed to be Any and // TensorRT will assume it represents a batch dimension only known at inference time. // Fewer Relay operators are supported in implicit batch mode. Default true. - .add_attr_option("use_implicit_batch", Bool(true)) + .add_attr_option("use_implicit_batch", runtime::Bool(true)) // If true, excludes sub-graphs which do not have multiply-accumulate operations, even though // TensorRT supports them. ad. This is a simple heuristic to optimize the partitioning between // TensorRT and TVM. Not required if using Collage for partitioning. Defalut false. - .add_attr_option("remove_no_mac_subgraphs", Bool(false)) + .add_attr_option("remove_no_mac_subgraphs", runtime::Bool(false)) // How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation. // Default 1G. - .add_attr_option("max_workspace_size", Integer(1 << 30)) + .add_attr_option("max_workspace_size", runtime::Int(1 << 30)) // If true, allows TensorRT to automatically convert float32 operations to float16. Must also be // enabled if any float16 operations are in the model. Note that TensorRT may still choose a // higher-precision kernel if it results in overall lower runtime, or if no low-precision // implementation exists. Default false. - .add_attr_option("use_fp16", Bool(false)) + .add_attr_option("use_fp16", runtime::Bool(false)) // If true, allows TensorRT to automatically convert float32 operations to uint8 // (aka quantized). Default false. - .add_attr_option("use_uint8", Bool(false)); + .add_attr_option("use_uint8", runtime::Bool(false)); } // namespace tensorrt } // namespace contrib diff --git a/src/relay/backend/contrib/uma/targets.cc b/src/relay/backend/contrib/uma/targets.cc index 244f243749c1..0499c0bba198 100644 --- a/src/relay/backend/contrib/uma/targets.cc +++ b/src/relay/backend/contrib/uma/targets.cc @@ -58,7 +58,7 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") .add_attr_option("model") .add_attr_option>("libs") .add_attr_option("host") - .add_attr_option("from_device") + .add_attr_option("from_device") .set_attr( attr::kRelayToTIR, relay::contrib::uma::RelayToTIR(target_name)) .set_attr("TIRToRuntime", relay::contrib::uma::TIRToRuntime); @@ -75,8 +75,9 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") } if (default_value->IsInstance()) { target_kind.add_attr_option(option_name, Downcast(default_value)); - } else if (default_value->IsInstance()) { - target_kind.add_attr_option(option_name, Downcast(default_value)); + } else if (default_value->IsInstance()) { + target_kind.add_attr_option(option_name, + Downcast(default_value)); } else { LOG(FATAL) << "TypeError: Only String, Integer, or Bool are supported. " << "Given attribute option type: " << attr_option.second->GetTypeKey(); diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc index 1d6caecb87ba..66feac4699e6 100644 --- a/src/relay/backend/executor.cc +++ b/src/relay/backend/executor.cc @@ -89,13 +89,13 @@ ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { /********** Register Executors and options **********/ TVM_REGISTER_EXECUTOR("aot") - .add_attr_option("link-params", Bool(true)) - .add_attr_option("unpacked-api") + .add_attr_option("link-params", runtime::Bool(true)) + .add_attr_option("unpacked-api") .add_attr_option("interface-api") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constant-byte-alignment"); + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constant-byte-alignment"); -TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", Bool(false)); +TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", runtime::Bool(false)); /********** Registry **********/ diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc index 923c9b2d5f65..0534298ea44d 100644 --- a/src/relay/backend/runtime.cc +++ b/src/relay/backend/runtime.cc @@ -88,9 +88,9 @@ RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { /********** Register Runtimes and options **********/ -TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); -TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); /********** Registry **********/ diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 0c0ff7290115..3e86e1c8eaf9 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -73,6 +73,42 @@ bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& exp } bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { + // Unwrapping arrays may find user-provided FFI types in the + // attributes (e.g. Defining pad_value as ((0,0), (0,0)) will result + // in runtime::Int. These need to be converted to compile-time IR + // types when encountered. + if (lhs->IsInstance() || + lhs->IsInstance() || + lhs->IsInstance()) { + TVMRetValue lhs_convert; + lhs_convert = lhs; + PrimExpr lhs_expr = lhs_convert; + return MatchRetValue(lhs_expr, rhs); + } + + // StructuralEqual doesn't check for conversions between FFI types + // and IR types, but the pattern-matcher should. Therefore, + // explicitly recurse into the array. + if (auto opt_lhs_array = lhs.as>()) { + if (Optional> opt_rhs_array = rhs) { + Array lhs_array = opt_lhs_array.value(); + Array rhs_array = opt_rhs_array.value(); + if (lhs_array.size() != rhs_array.size()) { + return false; + } + for (size_t i = 0; i < lhs_array.size(); i++) { + TVMRetValue rhs_item; + rhs_item = rhs_array[i]; + if (!MatchRetValue(lhs_array[i], rhs_item)) { + return false; + } + } + return true; + } else { + return false; + } + } + switch (rhs.type_code()) { case kDLInt: if (auto* val = lhs.as()) { diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 50d8531c7dd0..222aba4bd25b 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -79,7 +79,7 @@ Expr MakeReshape(Expr data, Array newshape, bool allowzero = false); Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, Integer rhs_end); -Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); +Expr MakeSplit(Expr data, Variant> indices_or_sections, int axis); Expr MakeSqueeze(Expr data, Array axis); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index fde6daa4d851..96f833d80505 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2984,10 +2984,10 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, Layout ret = Layout::Undef(); size_t size = 0; - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { size = sections->value; } else { - size = Downcast>(param->indices_or_sections).size() + 1; + size = Downcast>(param->indices_or_sections).size() + 1; } // If new_in_layouts are defined, this code tries to modify the layout. @@ -2998,13 +2998,12 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, param->axis = new_index; int factor = new_in_layouts[0].FactorOf(sp_dim); if (factor > 1) { - if (!param->indices_or_sections.as()) { - auto ios = Downcast>(param->indices_or_sections); - Array new_ios; + if (!param->indices_or_sections.as()) { + auto ios = Downcast>(param->indices_or_sections); + Array new_ios; for (const auto& v : ios) { - const IntImmNode* vint = v.as(); - new_ios.push_back(vint->value / factor); - if (vint->value % factor) { + new_ios.push_back(runtime::Int(v->value / factor)); + if (v->value % factor) { divisible = false; } } @@ -3041,7 +3040,7 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range."; ICHECK_GE(axis, 0) << "axis should be within the input dimension range."; - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { if (!data->shape[axis].as()) { ICHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == tir::make_zero(DataType::Int(64)))) @@ -3061,8 +3060,8 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TupleType(Array(fields))); } else { Array indices; - for (auto i : Downcast>(param->indices_or_sections)) { - indices.push_back(IntImm(DataType::Int(32), i.as()->value)); + for (auto index : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), index->value)); } auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; @@ -3097,19 +3096,20 @@ Array SplitCompute(const Attrs& attrs, const Array& inpu const auto param = attrs.as(); ICHECK(param != nullptr); - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { int64_t num_sections = sections->value; return Array{topi::split_sections(inputs[0], num_sections, param->axis)}; } else { Array indices; - for (auto i : Downcast>(param->indices_or_sections)) { - indices.push_back(IntImm(DataType::Int(32), i.as()->value)); + for (auto index : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), index->value)); } return Array{topi::split(inputs[0], indices, param->axis)}; } } -Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { +Expr MakeSplit(Expr data, Variant> indices_or_sections, + int axis) { auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); @@ -3117,17 +3117,7 @@ Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) { - if (args.type_codes[1] == kDLInt) { - // Note: we change it from Int(64) to Int(32) for now as - // combine_parallel_dense will transform the graph with Int(32). - // More invetigation is needs to check which one we should use. - *rv = - MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast(args[1])), args[2]); - } else { - *rv = MakeSplit(args[0], args[1], args[2]); - } -}); +TVM_REGISTER_GLOBAL("relay.op._make.split").set_body_typed(MakeSplit); RELAY_REGISTER_OP("split") .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. @@ -4157,11 +4147,13 @@ bool ScanopRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) { +Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Optional exclusive) { auto attrs = make_object(); attrs->dtype = dtype; attrs->axis = axis; - attrs->exclusive = exclusive; + if (exclusive.defined()) { + attrs->exclusive = exclusive.value(); + } static const Op& op = Op::Get("cumsum"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index a41e1e0d6674..74827f166b51 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -159,7 +159,7 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int index = 0; - auto split = MakeSplit(data, Integer(branches.size()), 0); + auto split = MakeSplit(data, runtime::Int(branches.size()), 0); for (const auto& branch : branches) { auto split_data = TupleGetItem(split, index++); auto squeezed_data = MakeSqueeze(split_data, {0}); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 34f986b251a2..df28506c6217 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -266,7 +266,7 @@ class ConstantFolder : public MixedModeMutator { // always use graph executor with no link-params dict.Set(tvm::attr::kExecutor, - relay::Executor::Create("graph", {{"link-params", Bool(false)}})); + relay::Executor::Create("graph", {{"link-params", runtime::Bool(false)}})); Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), eval_cpu_dev_, eval_cpu_target_, dict)); VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index edf1e4c99f4d..da7a8f6420cd 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -36,8 +36,6 @@ namespace tvm { namespace relay { -using namespace tvm::runtime; - /*! What is automatic differentiation(AD) and why is it important? * By AD, we roughly mean, given a term which denotes some mathematical function, * derive a term which denotes the derivative of that mathematical function. diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 5026b1bcba79..1112755b76a0 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -66,7 +66,7 @@ using CachedCastNodes = std::unordered_map, // Return array is of type : [MixedTypeConversionCategory (int), String, String] // The fields are : [ConversionCategory, accumulation_datatype, output_datatype] // Call is a call node, DataType is the mixed precision type -using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc( +using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc>( const Call& call_node, const std::string& target_dtype_str)>; /*! \brief This class transforms the given relay module into a version where @@ -372,7 +372,7 @@ class MixedPrecisionPass : public MixedModeMutator { if (attr_map.count(op)) { // Calculate the conversion category and dtypes from registered attribute. FTVMMixedPrecisionConversionType func = attr_map[op]; - Array op_descriptor = + Array> op_descriptor = func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type_)); ICHECK(op_descriptor.size() == 3) << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size() diff --git a/src/runtime/boxed_primitive.cc b/src/runtime/boxed_primitive.cc new file mode 100644 index 000000000000..9ab83a7b471c --- /dev/null +++ b/src/runtime/boxed_primitive.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/boxed_primitive.cc + * \brief Implementations of ObjectRef wrapper. + */ + +#include +#include + +namespace tvm { +namespace runtime { + +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); + +/* \brief Allow explicit construction of Box + * + * Convert a `bool` to `Box`. For use in FFI handling, to + * provide an umambiguous representation between `bool(true)` and + * `int(1)`. Will be automatically unboxed in the case where a + * `Box` is provided to a PackedFunc that requires `int` input, + * mimicking C++'s default conversions. + * + * This is only needed for Box, as Box and Box + * can be converted in C++ as part of `TVMArgValue::operator + * ObjectRef()` without ambiguity, postponing conversions until + * required. + */ +TVM_REGISTER_GLOBAL("runtime.BoxBool").set_body_typed([](bool value) { return Box(value); }); + +/* \brief Return the underlying boolean object. + * + * Used while unboxing a boolean return value during FFI handling. + * The return type is intentionally `int` and not `bool`, to avoid + * recursive unwrapping of boolean values. + * + * This is only needed for Box, as Box and Box + * can be unambiguously unboxed as part of + * `TVMRetValue::operator=(ObjectRef)`. + */ +TVM_REGISTER_GLOBAL("runtime.UnBoxBool").set_body_typed([](Box obj) -> int { + return obj->value; +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 57979b160ea7..04d36ad8bcab 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -361,14 +361,18 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r TVMAPISetLastError("ModuleGetFunction expects second argument to be a string"); return kTvmErrorFunctionCallWrongArgType; } - if (type_codes[2] != kDLInt) { + + if (type_codes[2] == kDLInt) { + query_imports = args[2].v_int64 != 0; + } else if (type_codes[2] == kTVMArgBool) { + query_imports = args[2].v_bool; + } else { TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer"); return kTvmErrorFunctionCallWrongArgType; } mod = (TVMModuleHandle)args[0].v_handle; name = args[1].v_str; - query_imports = args[2].v_int64 != 0; to_return = TVMModGetFunction(mod, name, query_imports, &ret_value->v_handle); if (to_return == 0) { diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 493bc3fb1dc9..f7204e372f6d 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -102,10 +102,10 @@ DRef BcastSessionObj::CallWithPacked(const TVMArgs& args) { int cnt = 0; for (int i = 3; i < num_args; ++i) { int type_code = type_codes[i]; - if (type_code != kDLInt && type_code != kDLUInt && type_code != kDLFloat && - type_code != kTVMDataType && type_code != kDLDevice && type_code != kTVMOpaqueHandle && - type_code != kTVMStr && type_code != kTVMNullptr && type_code != kTVMBytes && - type_code != kTVMObjectHandle) { + if (type_code != kDLInt && type_code != kDLUInt && type_code != kTVMArgBool && + type_code != kDLFloat && type_code != kTVMDataType && type_code != kDLDevice && + type_code != kTVMOpaqueHandle && type_code != kTVMStr && type_code != kTVMNullptr && + type_code != kTVMBytes && type_code != kTVMObjectHandle) { os << "\n Argument #" << i - 3 << " has unsupported type code: " << type_code << " (" << ArgTypeCode2Str(type_code) << ")"; cnt += 1; diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index d08dadb02bb9..485ebdb449da 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -325,6 +325,10 @@ struct RPCReference { channel->template Write(value.v_int64); break; } + case kTVMArgBool: { + channel->template Write(value.v_bool); + break; + } case kTVMDataType: { channel->Write(value.v_type); // padding @@ -432,6 +436,10 @@ struct RPCReference { channel->template Read(&(value.v_int64)); break; } + case kTVMArgBool: { + channel->template Read(&(value.v_bool)); + break; + } case kTVMDataType: { channel->Read(&(value.v_type)); int32_t padding = 0; diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 3908ad1112a0..9fe6fba80f5c 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -279,7 +279,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo * \param err_ctx Additional context if error occurs. */ void CheckPrimValueInfo(TVMArgValue arg, DataType dtype, Optional err_ctx) { - if (dtype.is_bool()) { + if (arg.IsObjectRef()) { + ObjectRef obj = arg.AsObjectRef(); + LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", expected dtype " << dtype + << ", but received ObjectRef of type " << obj->GetTypeKey(); + } else if (dtype.is_bool()) { arg.operator bool(); } else if (dtype.is_int()) { arg.operator int64_t(); @@ -426,7 +430,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.to_device") * \return Bool */ bool ReadIfCond(TVMArgValue cond) { - if (cond.type_code() == kDLInt) return cond.operator bool(); + if (cond.type_code() == kDLInt || cond.type_code() == kTVMArgBool) { + return cond.operator bool(); + } NDArray arr = cond.operator tvm::runtime::NDArray(); if (arr->device.device_type != kDLCPU) { arr = arr.CopyTo(DLDevice{kDLCPU, 0}); diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 54194e7e2a41..61bdec680a29 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -323,12 +323,33 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } else if (const auto* float_imm = value.as()) { // TODO(yelite): Make float number printing roundtrippable - output_.precision(17); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { output_ << '"' << float_imm->value << '"'; + } else if (std::nearbyint(float_imm->value) == float_imm->value) { + // Special case for floating-point values which would be + // formatted using %g, are not displayed in scientific + // notation, and whose fractional part is zero. + // + // By default, using `operator<<(std::ostream&, double)` + // delegates to the %g printf formatter. This strips off any + // trailing zeros, and also strips the decimal point if no + // trailing zeros are found. When parsed in python, due to the + // missing decimal point, this would incorrectly convert a float + // to an integer. Providing the `std::showpoint` modifier + // instead delegates to the %#g printf formatter. On its own, + // this resolves the round-trip errors, but also prevents the + // trailing zeros from being stripped off. + std::showpoint(output_); + std::fixed(output_); + output_.precision(1); + output_ << float_imm->value; } else { + std::defaultfloat(output_); + std::noshowpoint(output_); + output_.precision(17); output_ << float_imm->value; } + } else if (const auto* string_obj = value.as()) { output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index ef68b89b5bf4..686f486da6eb 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -30,6 +30,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return LiteralDoc::Str(s, p); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Bool obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Boolean(obj->value, p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Int obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Int(obj->value, p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Float obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Float(obj->value, p); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // "", [](Array array, ObjectPath p, IRDocsifier d) -> Doc { diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 6f9a8cbf8918..35a9f35db491 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -75,7 +75,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases - return LiteralDoc::Int(n->value, n_p); + if (n->dtype.is_bool()) { + return LiteralDoc::Boolean(n->value, n_p); + } else { + return LiteralDoc::Int(n->value, n_p); + } }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/support/array.h b/src/support/array.h index 0ca57a2410c5..0d4c8134787b 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -164,12 +164,14 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { + inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : vec) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -177,12 +179,14 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { + inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : vec) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -191,11 +195,13 @@ struct AsVectorImpl { template struct AsVectorImpl { inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : array) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -221,8 +227,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (int x : vec) { - result.push_back(Integer(x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } @@ -233,8 +241,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (int64_t x : vec) { - result.push_back(Integer(x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } @@ -245,8 +255,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (double x : vec) { - result.push_back(FloatImm(tvm::DataType::Float(64), x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index aec57a1eb20d..928cdfcab80b 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -189,6 +189,58 @@ TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian TVM_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); +TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRef").set_body_typed([](ObjectRef arg) -> ObjectRef { + return arg; +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray") + .set_body_typed([](Array arg) -> ObjectRef { return arg[0]; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") + .set_body_typed([](Map map, ObjectRef key) -> ObjectRef { + return map[key]; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") + .set_body_typed([](Map map) -> ObjectRef { return map; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { + return expr; +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") + .set_body_typed([](Array arr) -> ObjectRef { + for (ObjectRef item : arr) { + CHECK(item->IsInstance()) + << "Array contained " << item->GetTypeKey() << " when it should contain PrimExpr"; + } + return arr; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") + .set_body_typed([](Array> arr) -> ObjectRef { + for (ObjectRef item : arr) { + CHECK(item->IsInstance() || item->IsInstance()) + << "Array contained " << item->GetTypeKey() + << " when it should contain either PrimExpr or PackedFunc"; + } + return arr; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") + .set_body_typed([](Map map) -> ObjectRef { + for (const auto& kv : map) { + ObjectRef value = kv.second; + CHECK(value->IsInstance()) + << "Map contained " << value->GetTypeKey() << " when it should contain PrimExpr"; + } + return map; + }); + /** * Simple event logger that can be used for testing purposes */ diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 481ba39cc7b1..21899a12c4b0 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -347,18 +347,26 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value } case builtin::kTVMValueContent: { ICHECK_EQ(t.lanes(), 1); - ICHECK(t.is_handle() || t.bits() == 64); - if (t.is_int()) { + if (t.is_bool()) { + // The stride between adjacent entries is still + // `sizeof(TVMValue)==64`, even if the enum currently holds a + // boolean. + buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_int64_, buf, index); + buf = builder_->CreatePointerCast(buf, DTypeToLLVMType(t)->getPointerTo()); + return TypedPointer(t_int8_, buf); + } else if (t.is_int() && t.bits() == 64) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); - } else if (t.is_float()) { + } else if (t.is_float() && t.bits() == 64) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); - } else { - ICHECK(t.is_handle()); + } else if (t.is_handle()) { buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); + } else { + LOG(DEBUG) << "DataType " << t << " cannot be stored into a TVMValue"; } } default: @@ -1366,9 +1374,16 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { return builder_->CreatePointerCast(ref.addr, t_void_p_); - } else { - return builder_->CreateLoad(ref.type, ref.addr); } + + llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr); + + if (op->dtype == DataType::Bool()) { + struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value); + } + + return struct_value; + } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index dd5a3fb681ee..0406dcf951bb 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -294,10 +294,10 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) target_options_.MCOptions.ABIName = Downcast(target.Get("mabi")); } - auto maybe_level = Downcast(target.Get("opt-level")); + auto maybe_level = target.Get("opt-level").as(); #if TVM_LLVM_VERSION <= 170 if (maybe_level.defined()) { - int level = maybe_level->value; + int level = maybe_level.value()->value; if (level <= 0) { opt_level_ = llvm::CodeGenOpt::None; } else if (level == 1) { @@ -313,7 +313,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } #else if (maybe_level.defined()) { - int level = maybe_level->value; + int level = maybe_level.value()->value; if (level <= 0) { opt_level_ = llvm::CodeGenOptLevel::None; } else if (level == 1) { @@ -333,8 +333,12 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) // Fast math options - auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool { - return Downcast(target.Get(flag.str()).value_or(Bool(false))); + auto GetBoolFlag = [&target](llvm::StringRef name) -> bool { + if (auto flag = target.Get(name.str())) { + return Downcast(flag); + } else { + return false; + } }; if (GetBoolFlag("fast-math")) { #if TVM_LLVM_VERSION >= 60 diff --git a/src/target/tag.cc b/src/target/tag.cc index 9eca3072df0e..d45bf61a38f1 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -76,61 +76,61 @@ TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}, + {"num-cores", runtime::Int(4)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}}}}); + {"num-cores", runtime::Int(4)}}}}); #if TVM_LLVM_VERSION >= 110 TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") .set_config({{"kind", String("cuda")}, {"arch", String("sm_72")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(8)}}}}); + {"num-cores", runtime::Int(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(6)}}}}); + {"num-cores", runtime::Int(6)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a78")}, - {"num-cores", Integer(8)}}}}); + {"num-cores", runtime::Int(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a78")}, - {"num-cores", Integer(12)}}}}); + {"num-cores", runtime::Int(12)}}}}); #endif // TVM_LLVM_VERSION >= 110 #endif // TVM_LLVM_HAS_AARCH64_TARGET @@ -139,10 +139,10 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") {"kind", String("cuda")}, \ {"keys", Array{"cuda", "gpu"}}, \ {"arch", String(Arch)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"max_threads_per_block", Integer(1024)}, \ - {"thread_warp_size", Integer(32)}, \ - {"registers_per_block", Integer(RegPerBlock)}, \ + {"max_shared_memory_per_block", runtime::Int(SharedMem)}, \ + {"max_threads_per_block", runtime::Int(1024)}, \ + {"thread_warp_size", runtime::Int(32)}, \ + {"registers_per_block", runtime::Int(RegPerBlock)}, \ }) // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus @@ -158,9 +158,9 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(41943040)); + .with_config("l2_cache_size_bytes", runtime::Int(41943040)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(52428800)); + .with_config("l2_cache_size_bytes", runtime::Int(52428800)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); @@ -263,7 +263,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(75497472)); + .with_config("l2_cache_size_bytes", runtime::Int(75497472)); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536); @@ -416,7 +416,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ {"keys", Array{"x86", "cpu"}}, \ {"mcpu", String(Arch)}, \ - {"num-cores", Integer(Cores)}}); + {"num-cores", runtime::Int(Cores)}}); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.xlarge", 2, "skylake-avx512"); @@ -432,9 +432,9 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ TVM_REGISTER_TARGET_TAG(Name).set_config( \ {{"kind", String("metal")}, \ - {"max_threads_per_block", Integer(ThreadsPerBlock)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"thread_warp_size", Integer(WarpSize)}, \ + {"max_threads_per_block", runtime::Int(ThreadsPerBlock)}, \ + {"max_shared_memory_per_block", runtime::Int(SharedMem)}, \ + {"thread_warp_size", runtime::Int(WarpSize)}, \ {"host", Map{{"kind", String("llvm")}, \ {"mtriple", String("arm64-apple-macos")}, \ {"mcpu", String("apple-latest")}}}}); diff --git a/src/target/target.cc b/src/target/target.cc index cd2e3714e422..a8337b58ae9b 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -359,24 +359,31 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { std::string interp_str = Interpret(str); - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing integer + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex() || + info.type_index == runtime::Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + // Parsing integer or boolean std::istringstream is(interp_str); int v; if (!(is >> v)) { std::string lower(interp_str.size(), '\x0'); std::transform(interp_str.begin(), interp_str.end(), lower.begin(), [](unsigned char c) { return std::tolower(c); }); - // Bool is a subclass of IntImm, so allow textual boolean values. + // Mimic C++ automatic conversions, allowing bool to be used for + // integer parameters. if (lower == "true") { v = 1; } else if (lower == "false") { v = 0; } else { - throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); + throw Error(": Cannot parse integer from string: " + interp_str); } } - return Integer(v); + + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return runtime::Int(v); + } else { + return runtime::Bool(v); + } } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing string, strip leading/trailing spaces, and enclosing quotes if any auto start = interp_str.find_first_not_of(' '); @@ -410,13 +417,13 @@ ObjectRef TargetInternal::ParseType(const std::string& str, ObjectRef TargetInternal::ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) { - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer - return GetRef(ObjTypeCheck(obj, "Integer")); - } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return GetRef(ObjTypeCheck(obj, "runtime.BoxInt")); + } else if (info.type_index == String::ContainerType::RuntimeTypeIndex()) { // Parsing string return GetRef(ObjTypeCheck(obj, "String")); - } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target if (auto opt = obj.as()) { return opt.value(); @@ -483,7 +490,11 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, /********** Stringifying **********/ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { - if (const auto* p = obj.as()) { + if (const auto* p = obj.as()) { + return std::to_string(p->value); + } else if (const auto* p = obj.as()) { + return std::to_string(p->value); + } else if (const auto* p = obj.as()) { return std::to_string(p->value); } if (auto tvm_str = obj.as()) { @@ -494,7 +505,7 @@ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { } return u; } - LOG(FATAL) << "Cannot stringify this object"; + LOG(FATAL) << "Cannot stringify object of type " << obj->GetTypeKey(); } std::string TargetInternal::StringifyArray(const ArrayNode& array) { @@ -953,7 +964,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // If requested, query attributes from the device. User-specified // parameters take precedence over queried parameters. if (attrs.count("from_device")) { - int device_id = Downcast(attrs.at("from_device")).IntValue(); + int device_id = Downcast(attrs.at("from_device"))->value; attrs.erase("from_device"); auto device_params = QueryDevice(device_id, target.get()); @@ -1006,38 +1017,13 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, for (const auto& kv : target->kind->key2vtype_) { const String& key = kv.first; - const TargetKindNode::ValueTypeInfo& type_info = kv.second; TVMRetValue ret; api->GetTargetProperty(device, key, &ret); - switch (ret.type_code()) { - case kTVMNullptr: - // Nothing returned for this parameter, move on to the next one. - continue; - - case kTVMArgInt: - if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Integer(static_cast(ret)); - } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Bool(static_cast(ret)); - } else { - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received integer from device api"; - } - break; - - case kTVMStr: - ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) - << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received string from device api"; - output[key] = String(ret.operator std::string()); - break; - - default: - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; - break; + // Delegate conversion from TVMRetValue to the FFI's default conversions. + if (Optional opt = ret) { + output[key] = opt.value(); } } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 708d3ccd7621..fced74c3a559 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -243,7 +243,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { * \return The updated attributes */ TargetJSON TestTargetParser(TargetJSON target) { - Map features = {{"is_test", Bool(true)}}; + Map features = {{"is_test", runtime::Bool(true)}}; target.Set("features", features); return target; } @@ -256,16 +256,16 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mtriple") .add_attr_option("mfloat-abi") .add_attr_option("mabi") - .add_attr_option("num-cores") + .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags - .add_attr_option("fast-math") // implies all the below - .add_attr_option("fast-math-nnan") - .add_attr_option("fast-math-ninf") - .add_attr_option("fast-math-nsz") - .add_attr_option("fast-math-arcp") - .add_attr_option("fast-math-contract") - .add_attr_option("fast-math-reassoc") - .add_attr_option("opt-level") + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") // LLVM command line flags, see below .add_attr_option>("cl-opt") // LLVM JIT engine mcjit/orcjit @@ -273,7 +273,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .set_default_keys({"cpu"}) // Force the external codegen kind attribute to be registered, even if no external // codegen targets are enabled by the TVM build. - .set_attr(tvm::attr::kIsExternalCodegen, Bool(false)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(false)) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); // Note regarding the "cl-opt" attribute: @@ -301,28 +301,29 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("mcpu") .add_attr_option("march") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constants-byte-alignment") + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constants-byte-alignment") .set_default_keys({"cpu"}) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("arch") - .add_attr_option("max_shared_memory_per_block") - .add_attr_option("max_threads_per_block") - .add_attr_option("thread_warp_size", Integer(32)) - .add_attr_option("registers_per_block") - .add_attr_option("l2_cache_size_bytes") - .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_threads_per_block") + .add_attr_option("thread_warp_size", runtime::Int(32)) + .add_attr_option("registers_per_block") + .add_attr_option("l2_cache_size_bytes") + .add_attr_option("max_num_threads", + runtime::Int(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("mtriple") - .add_attr_option("max_num_threads", Integer(1024)) - .add_attr_option("thread_warp_size", Integer(32)) + .add_attr_option("max_num_threads", runtime::Int(1024)) + .add_attr_option("thread_warp_size", runtime::Int(32)) .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateNVPTXAttrs); @@ -332,24 +333,24 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option>("mattr") // TODO(masahi): Support querying from a target device // On RDNA cards, thread_warp_size should be 32 - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(65536)) - .add_attr_option("thread_warp_size", Integer(64)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(65536)) + .add_attr_option("thread_warp_size", runtime::Int(64)) .set_default_keys({"rocm", "gpu"}) .set_target_parser(UpdateROCmAttrs); TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(16384)) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("texture_spatial_limit", Integer(16384)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(16384)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("thread_warp_size", runtime::Int(1)) + .add_attr_option("texture_spatial_limit", runtime::Int(16384)) // Faced that Qualcomm OpenCL runtime crashed without any error message in // the case when the number of kernel arguments was pretty big. OpenCL doesn't // specify any limitations on the number of kernel arguments. max_function_args // equals to 128 looks like a reasonable number of kernel arguments. - .add_attr_option("max_function_args", Integer(128)) + .add_attr_option("max_function_args", runtime::Int(128)) .set_default_keys({"opencl", "gpu"}); // The metal has some limitations on the number of input parameters. This is why attribute @@ -358,55 +359,55 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) // https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc // See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf TVM_REGISTER_TARGET_KIND("metal", kDLMetal) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(32768)) - .add_attr_option("thread_warp_size", Integer(16)) - .add_attr_option("max_function_args", Integer(31)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(32768)) + .add_attr_option("thread_warp_size", runtime::Int(16)) + .add_attr_option("max_function_args", runtime::Int(31)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option>("mattr") // Feature support - .add_attr_option("supports_float16") - .add_attr_option("supports_float32", Bool(true)) - .add_attr_option("supports_float64") - .add_attr_option("supports_int8") - .add_attr_option("supports_int16") - .add_attr_option("supports_int32", Bool(true)) - .add_attr_option("supports_int64") - .add_attr_option("supports_8bit_buffer") - .add_attr_option("supports_16bit_buffer") - .add_attr_option("supports_storage_buffer_storage_class") - .add_attr_option("supports_push_descriptor") - .add_attr_option("supports_dedicated_allocation") - .add_attr_option("supports_integer_dot_product") - .add_attr_option("supports_cooperative_matrix") - .add_attr_option("supported_subgroup_operations") + .add_attr_option("supports_float16") + .add_attr_option("supports_float32", runtime::Bool(true)) + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32", runtime::Bool(true)) + .add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix") + .add_attr_option("supported_subgroup_operations") // Physical device limits - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("max_block_size_x") - .add_attr_option("max_block_size_y") - .add_attr_option("max_block_size_z") - .add_attr_option("max_push_constants_size") - .add_attr_option("max_uniform_buffer_range") - .add_attr_option("max_storage_buffer_range") - .add_attr_option("max_per_stage_descriptor_storage_buffer") - .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("thread_warp_size", runtime::Int(1)) + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") // Other device properties .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_name") - .add_attr_option("driver_version") - .add_attr_option("vulkan_api_version") - .add_attr_option("max_spirv_version") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") // Tags .set_default_keys({"vulkan", "gpu"}); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) - .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_num_threads", runtime::Int(256)) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL) // line break @@ -423,8 +424,8 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) .add_attr_option("mcpu") .add_attr_option("mtriple") .add_attr_option>("llvm-options") - .add_attr_option("num-cores") - .add_attr_option("vtcm-capacity") + .add_attr_option("num-cores") + .add_attr_option("vtcm-capacity") .set_default_keys({"hexagon", "cpu"}); TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 5797d2295bab..fb839c28da96 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -56,10 +56,25 @@ TVM_REGISTER_NODE_TYPE(ComputeOpNode); /// Verify if ComputeOp is valid with respect to Reduce operations. static void VerifyComputeOp(const ComputeOpNode* op); -inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && StructuralEqual()(a->condition, b->condition) && - ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); +static inline void AssertReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { + const char* shared_text = + "When a TE compute node produces multiple outputs, " + "each of which is a reduction, " + "each reduction must be structurally identical, " + "except for the ReduceNode::value_index. "; + + StructuralEqual eq; + + ICHECK(a->combiner.same_as(b->combiner)) << shared_text << "However, the reduction operation " + << a->combiner << " does not match " << b->combiner; + ICHECK(a->source.same_as(b->source)) + << shared_text << "However, the input " << a->source << " does not match " << b->source; + ICHECK(eq(a->axis, b->axis)) << shared_text << "However, the reduction axis " << a->axis + << " does not match " << b->axis; + ICHECK(eq(a->condition, b->condition)) << shared_text << "However, the predicate " << a->condition + << " does not match " << b->condition; + ICHECK(eq(a->init, b->init)) << shared_text << "However, the initial value " << a->init + << " does not match " << b->init; } int ComputeOpNode::num_outputs() const { return body.size(); } @@ -529,8 +544,7 @@ class ComputeVerifier final : protected tir::ExprVisitor { << "with being Reduce operation or not."; if (reduce && reduce_) { - ICHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; + AssertReduceEqual(reduce, reduce_); } level_ = 0; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 2eb0693685a6..b5a87d9446d8 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -355,11 +355,12 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Array seq_stmt; if (compute_op->body[0]->IsInstance()) { auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { - return a->combiner.same_as(b->combiner) && // - a->source.same_as(b->source) && // - a->axis.same_as(b->axis) && // - a->condition.same_as(b->condition) && // - ((a->init.empty() && b->init.empty()) || a->init.same_as(b->init)); + StructuralEqual eq; + return eq(a->combiner, b->combiner) && // + eq(a->source, b->source) && // + eq(a->axis, b->axis) && // + eq(a->condition, b->condition) && // + eq(a->init, b->init); }; PrimExpr expr_body = compute_op->body[0]; @@ -370,7 +371,9 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in const tir::ReduceNode* reduce_ = compute_op->body[k].as(); ICHECK(reduce_); ICHECK(f_reducer_equal(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should have the same attribute except value_index"; + << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " + << "but the first argument has body " << GetRef(reduce_) << ", while the " << k + << "-th argument has body " << GetRef(reduce); tensors.push_back(compute_op.output(k)); } diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 4f5df7ad3024..774a0f8f1f89 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -63,7 +63,17 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { } TVM_REGISTER_GLOBAL("te.Placeholder") - .set_body_typed([](Array shape, DataType dtype, std::string name) { + .set_body_typed([](Variant> shape_arg, DataType dtype, + std::string name) { + auto shape = [&]() -> Array { + if (auto arg_expr = shape_arg.as()) { + return {arg_expr.value()}; + } else if (auto arg_array = shape_arg.as>()) { + return arg_array.value(); + } else { + LOG(FATAL) << "Variant did not contain either allowed type"; + } + }(); return placeholder(shape, dtype, name); }); diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index c38c5a5c800b..1ad8914e48cc 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -124,9 +124,10 @@ void ReplaceDataFlow(const Array& stages, std::unordered_mapcombiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) && - ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); + StructuralEqual struct_equal; + return struct_equal(a->combiner, b->combiner) && struct_equal(a->source, b->source) && + struct_equal(a->axis, b->axis) && struct_equal(a->condition, b->condition) && + struct_equal(a->init, b->init); } Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 3a41c5ac5a25..70e82a605369 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -134,7 +134,7 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) { int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true); if (target.defined() && target->kind->name == "hexagon") { - auto value = Downcast(target->attrs.at("vtcm-capacity"))->value; + auto value = target->GetAttr("vtcm-capacity").value()->value; if (value > 0) return value; } return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 1506082003fd..c38237a664f7 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -35,6 +35,18 @@ namespace tvm { namespace tir { +/* \brief Convert an object to a PrimExpr + * + * All conversions to a PrimExpr are performed as part of the FFI, + * when calling a function that accepts a PrimExpr as an argument. If + * a function must normalize to a PrimExpr (e.g. before accessing the + * `expr.dtype` field), this function allows the FFI conversions to be + * explicitly invoked. + */ +TVM_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { + return expr; +}); + #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ Name::Name(PrimExpr a, PrimExpr b, Span span) { \ using T = Name::ContainerType; \ @@ -546,7 +558,9 @@ Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { + .set_body_typed([](DataType type, RelayExpr op, + Array> args, + Span span) { Array prim_expr_args; for (const auto& it : args) { ICHECK(it->IsInstance() || it->IsInstance() || @@ -707,9 +721,11 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis if (!init.empty()) { ICHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs"; for (size_t i = 0; i < init.size(); i++) { + ICHECK(init[i].defined()) << "Init value must be defined"; ICHECK(init[i]->IsInstance() || init[i]->IsInstance() || init[i]->IsInstance()) - << "init can only be a IntImm, FloatImm or ProducerLoad"; + << "init can only be a IntImm, FloatImm or ProducerLoad, " + << "but received " << init[i] << " of type " << init[i]->GetTypeKey(); } } n->dtype = source[value_index].dtype(); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 14dd0eadb65c..2c94b9d8646b 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -27,6 +27,8 @@ #include #include +#include "utils.h" + namespace tvm { namespace tir { namespace { @@ -79,6 +81,11 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, if (!ret_type.defined()) { ret_type = VoidType(); } + + if (attrs.defined()) { + attrs = Downcast(NormalizeAttributeObject(attrs)); + } + auto n = make_object(); n->params = std::move(params); n->body = std::move(body); diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index b30d0caf6af3..78fb9365cc71 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -414,7 +414,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx /**************** Implementation ****************/ -PrimFunc Specialize(PrimFunc func, const Map& param_map) { +PrimFunc Specialize(PrimFunc func, const Map>& param_map) { VarMap var_map; for (const auto& kv : param_map) { const Var& param = kv.first; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 5df76450ff1e..9c8f580b5413 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -27,6 +27,7 @@ #include #include "buffer_common.h" +#include "utils.h" namespace tvm { namespace tir { @@ -61,6 +62,15 @@ TVM_REGISTER_NODE_TYPE(LetStmtNode); // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { + // The nodes are not required to be a TIR type, and may legally + // contain any ObjectRef. However, normalizing to an IR type if + // possible prevents spurious discrepancies in StructuralEqual(). + if (auto opt = node.as()) { + node = Bool(opt.value()); + } else if (auto opt = node.as()) { + node = Integer(opt.value()); + } + auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); @@ -109,13 +119,21 @@ TVM_REGISTER_GLOBAL("tir.AssertStmt") // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional thread_binding, Map annotations, Span span) { + ICHECK(loop_var.defined()); ICHECK(min.defined()); ICHECK(extent.defined()); - ICHECK(min.dtype().is_scalar()); - ICHECK(extent.dtype().is_scalar()); - ICHECK(loop_var.dtype().is_scalar()); ICHECK(body.defined()); + auto require_scalar_int_dtype = [&](PrimExpr expr, const char* field_name) { + auto dtype = expr.dtype(); + CHECK(dtype.is_scalar() && (dtype.is_int() || dtype.is_uint())) + << "TIR For nodes require a scalar integer as the " << field_name << ", but received " + << expr << " with dtype " << dtype; + }; + require_scalar_int_dtype(loop_var, "loop_var"); + require_scalar_int_dtype(min, "min"); + require_scalar_int_dtype(extent, "extent"); + // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them // without raising errors. auto try_promote_imm_dtype = [&](const PrimExpr& e) { @@ -136,6 +154,8 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); @@ -234,6 +254,8 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim ICHECK(condition.defined()); ICHECK(condition.dtype().is_bool()); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; @@ -288,6 +310,8 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array ext ICHECK(body.defined()); ICHECK(data_or_idx.defined()); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; @@ -652,6 +676,8 @@ Block::Block(Array iter_vars, Array reads, Array init, Array alloc_buffers, Array match_buffers, Map annotations, Span span) { + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->iter_vars = std::move(iter_vars); node->reads = std::move(reads); diff --git a/src/tir/ir/utils.cc b/src/tir/ir/utils.cc new file mode 100644 index 000000000000..0e3dc1237894 --- /dev/null +++ b/src/tir/ir/utils.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tir/ir/utils.cc + * \brief Utilities for manipulating TIR + */ +#include "utils.h" + +#include + +namespace tvm { +namespace tir { + +ObjectRef NormalizeAttributeObject(ObjectRef obj) { + if (const auto* runtime_int = obj.as()) { + return Integer(runtime_int->value); + } else if (const auto* runtime_bool = obj.as()) { + return Bool(runtime_bool->value); + } else if (const auto* runtime_float = obj.as()) { + return FloatImm(DataType::Float(32), runtime_float->value); + } else if (auto opt_array = obj.as>()) { + return opt_array.value().Map(NormalizeAttributeObject); + } else if (auto opt_map = obj.as>()) { + Map new_map; + bool is_same = true; + + for (const auto& [key, obj] : opt_map.value()) { + ObjectRef new_obj = NormalizeAttributeObject(obj); + is_same = is_same && obj.same_as(new_obj); + new_map.Set(key, new_obj); + } + + if (is_same) { + return obj; + } else { + return new_map; + } + } else if (auto dict_attrs = obj.as()) { + auto new_attrs = Downcast>(NormalizeAttributeObject(dict_attrs->dict)); + if (new_attrs.same_as(dict_attrs->dict)) { + return GetRef(dict_attrs); + } else { + return DictAttrs(new_attrs); + } + } else { + return obj; + } +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/utils.h b/src/tir/ir/utils.h new file mode 100644 index 000000000000..b1f7a722899f --- /dev/null +++ b/src/tir/ir/utils.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/ir/utils.h + * \brief Utilities for manipulating TIR + */ +#ifndef TVM_TIR_IR_UTILS_H_ +#define TVM_TIR_IR_UTILS_H_ + +#include + +namespace tvm { +namespace tir { + +/* \brief Normalize an ObjectRef held + * + * Where possible, the IR should be normalized contain IR types. For + * example, holding a `tir::IntImm` instead of a `runtime::Int`. In + * attributes, this is not always possible, as attributes may refer to + * non-IR objects. + * + * This function normalizes any `runtime::Int`, `runtime::Bool`, + * `runtime::Float`, or containers of those types to the corresponding + * IR type. + * + * \param obj The attribute object to be normalized + * + * \returns The normalized attribute + */ +ObjectRef NormalizeAttributeObject(ObjectRef obj); + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_IR_UTILS_H_ diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index c79a148e4b6e..dad4ea98d614 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -229,9 +229,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } PrimExpr ret(PrimExpr value, Span span) { + CHECK(value.defined()); return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } +TVM_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); + // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { using namespace tir; @@ -1048,12 +1051,15 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[0].type_code() == kDLInt) { - *ret = tir::make_const(args[1], args[0].operator int64_t(), args[2]); - } else if (args[0].type_code() == kDLFloat) { - *ret = tir::make_const(args[1], args[0].operator double(), args[2]); + if (auto opt = args[0].TryAsInt()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); + } else if (auto opt = args[0].TryAsBool()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); + } else if (auto opt = args[0].TryAsFloat()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); } else { - LOG(FATAL) << "only accept int or float"; // FIXME + LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or bool, " + << "but instead received argument with type code " << args[0].type_code(); // FIXME } }); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index cda501cd992e..73b5ff3fafd4 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -233,9 +233,9 @@ support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } -ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); @@ -914,6 +914,14 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_ if (ann_val.as()) { return ann_val; } + if (auto* runtime_int = ann_val.as()) { + return IntImm(DataType::Int(32), runtime_int->value); + } else if (auto* runtime_float = ann_val.as()) { + return FloatImm(DataType::Float(32), runtime_float->value); + } else if (auto* runtime_bool = ann_val.as()) { + return Bool(runtime_bool->value); + } + if (const auto* expr = ann_val.as()) { ICHECK(!ann_val->IsInstance()) << "TypeError: runtime::String is expected, but gets StringImm"; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 4eccff10a2c7..092bcf0c79f9 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -87,8 +87,9 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) override; + ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 122c5ff0d9fe..9209e6578687 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -439,6 +439,11 @@ inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os } else if (const auto* float_imm = obj.as()) { os.precision(17); os << float_imm->value; + } else if (const auto* runtime_int = obj.as()) { + os << runtime_int->value; + } else if (const auto* runtime_float = obj.as()) { + os.precision(17); + os << runtime_float->value; } else if (const auto* array = obj.as()) { os << '['; bool is_first = true; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index fe1c1850dcd5..fd1349e4a3ec 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -55,8 +55,9 @@ std::vector SampleWithoutReplacement( * \return The random variable sampled from candidates */ TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision); + const Array& candidates, + const Array& probs, + Optional* decision); /*! * \brief Create a sampling function that does multinomial sampling. * \param rand_state The random state. diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index 92c3423bcbbb..4c7b208e964f 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include "../../ir/utils.h" #include "../utils.h" namespace tvm { @@ -97,6 +98,8 @@ struct AnnotateTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val, String ann_key) { + ann_val = NormalizeAttributeObject(ann_val); + if (auto block = block_or_loop_rv.as()) { return sch->Annotate(block.value(), ann_key, ann_val); } diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 2a2f17355ca6..8e16f50b8b95 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -163,19 +163,18 @@ std::vector SampleWithoutReplacement( } int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision) { + const Array& candidates, const Array& probs, + Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; int32_t i = -1; int32_t n = candidates.size(); if (decision->defined()) { - const auto* int_imm = decision->as(); - i = int_imm->value; + i = decision->value()->value; CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - std::vector weights = support::AsVector(probs); + std::vector weights = support::AsVector(probs); std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); @@ -183,8 +182,8 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st << ", but decision is: " << i; } - *decision = Integer(i); // decision is guaranteed not to be nullptr. - return candidates[i].IntValue(); + *decision = runtime::Int(i); // decision is guaranteed not to be nullptr. + return candidates[i]->value; } std::function MakeMultinomialSampler( @@ -461,24 +460,11 @@ struct SampleCategoricalTraits : public UnpackedInstTraits candidates, // - Array probs, // - Optional decision) { - Array probs_float = probs.Map([](const ObjectRef& prob) { - const auto* prob_float = prob.as(); - if (prob_float != nullptr) { - return GetRef(prob_float); - } - const auto* prob_int = prob.as(); - if (prob_int != nullptr) { - return FloatImm(DataType::Float(32), static_cast(prob_int->value)); - } - LOG(FATAL) - << "SampleCategorical does not accept probability with type other than float or int."; - throw; - }); - return sch->SampleCategorical(candidates, probs_float, decision); + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + Array candidates, // + Array probs, // + Optional decision) { + return sch->SampleCategorical(candidates, probs, decision); } static String UnpackedAsPython(Array outputs, // diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 4b10df7e9728..6e243bf19198 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -112,7 +112,9 @@ Array TranslateInputRVs( } else if (const auto* str_obj = input.as()) { // Case 2. string => "content" results.push_back(String('"' + std::string(str_obj->data) + '"')); - } else if (input->IsInstance() || input->IsInstance()) { + } else if (input->IsInstance() || input->IsInstance() || + input->IsInstance() || + input->IsInstance()) { // Case 3. integer or floating-point number results.push_back(input); } else if (input->IsInstance()) { @@ -149,7 +151,9 @@ Array TranslateInputRVs(const Array& inputs, results.reserve(inputs.size()); for (const ObjectRef& input : inputs) { // Case 3. integer or floating-point number - if (input->IsInstance() || input->IsInstance()) { + if (input->IsInstance() || input->IsInstance() || + input->IsInstance() || + input->IsInstance()) { results.push_back(input); continue; } @@ -388,9 +392,9 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { try { const ArrayNode* arr = decision_entry.as(); ICHECK(arr && arr->size() == 2); - const IntImmNode* arr0 = arr->at(0).as(); + auto arr0 = arr->at(0).as(); ICHECK(arr0); - index = arr0->value; + index = arr0.value(); decision = arr->at(1); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, " diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 16c4350aaee6..1611109d7735 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -53,9 +53,9 @@ Schedule TracedScheduleNode::Copy() { /******** Schedule: Sampling ********/ -ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { ExprRV result = CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 686d84ebc6fe..78629e84f039 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,8 +47,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) final; + ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) final; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index cc33ba9f86c2..14672f568549 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -231,7 +231,7 @@ class PrimFuncInliner : StmtExprMutator { << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; - Map param_map; + Map> param_map; for (size_t i = 0; i < callee->params.size(); i++) { param_map.Set(callee->params[i], args[i]); } diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 423b0ca92237..2948773321dd 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -155,6 +155,7 @@ inline DataType APIType(DataType t) { ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; if (t.is_handle()) return t; ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; + if (t.is_bool()) return DataType::Bool(); if (t.is_uint() || t.is_int()) return DataType::Int(64); ICHECK(t.is_float()); return DataType::Float(64); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 1a3888a7cd48..1cde4f2ebe7d 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -511,6 +511,8 @@ class BuiltinLower : public StmtExprMutator { arg_tcode = kTVMStr; } else if (IsArrayHandle(arg)) { arg_tcode = kTVMDLTensorHandle; + } else if (arg.dtype().is_bool()) { + arg_tcode = kTVMArgBool; } // opaque handle need to set the kind properly if (arg_tcode == kTVMOpaqueHandle) { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d327cdfa8393..9f2f1295fece 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -263,15 +263,15 @@ PrimFunc MakePackedAPI(PrimFunc func) { // --------------------------- // local function definitions // load i-th argument as type t - auto f_arg_value = [&](DataType t, int i) { + auto f_arg_value = [&](DataType arg_type, int i) { Array call_args{v_packed_args, IntImm(DataType::Int(32), i), IntImm(DataType::Int(32), builtin::kTVMValueContent)}; // load 64 bit version - DataType api_type = APIType(t); + DataType api_type = APIType(arg_type); PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); // cast to the target version. - if (api_type != t) { - res = Cast(t, res); + if (api_type != arg_type) { + res = Cast(arg_type, res); } return res; }; @@ -319,10 +319,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { continue; } - var_def.emplace_back(f_arg_value(param.dtype(), i), param); - if (func_ptr->buffer_map.count(param)) { - buffer_def.emplace_back(param, func_ptr->buffer_map[param]); - } + PrimExpr arg_value; // type code checks Var tcode(param->name_hint + ".code", DataType::Int(32)); @@ -335,15 +332,45 @@ PrimFunc MakePackedAPI(PrimFunc func) { seq_init.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = f_arg_value(param.dtype(), i); + } else if (t.is_bool()) { + std::ostringstream msg; + msg << name_hint << ": Expect arg[" << i << "] to be boolean"; + seq_init.emplace_back( + AssertStmt(tcode == kTVMArgBool || tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = Call(t, builtin::if_then_else(), + { + tcode == kTVMArgBool, + f_arg_value(DataType::Bool(), i), + cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)), + }); + } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_init.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); + seq_init.emplace_back( + AssertStmt(tcode == kDLInt || tcode == kTVMArgBool, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = Call(t, builtin::if_then_else(), + { + tcode == kTVMArgInt, + f_arg_value(t, i), + cast(t, f_arg_value(DataType::Bool(), i)), + }); } else { ICHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; seq_init.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = f_arg_value(param.dtype(), i); + } + + var_def.emplace_back(arg_value, param); + if (func_ptr->buffer_map.count(param)) { + buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } } diff --git a/tests/cpp/relay/backend/runtime_test.cc b/tests/cpp/relay/backend/runtime_test.cc index 53ea7e39ed59..adabb9b9b6cf 100644 --- a/tests/cpp/relay/backend/runtime_test.cc +++ b/tests/cpp/relay/backend/runtime_test.cc @@ -26,13 +26,13 @@ namespace tvm { namespace relay { TVM_REGISTER_RUNTIME("TestRuntime") - .add_attr_option("my_bool") + .add_attr_option("my_bool") .add_attr_option>("your_names") .add_attr_option("another_option") - .add_attr_option("defaulty_the_default_option", Bool(false)); + .add_attr_option("defaulty_the_default_option", runtime::Bool(false)); TEST(Runtime, Create) { - Map attrs = {{"my_bool", Bool(true)}}; + Map attrs = {{"my_bool", runtime::Bool(true)}}; Runtime my_runtime = Runtime::Create("TestRuntime", attrs); ASSERT_EQ(my_runtime->GetAttr("my_bool"), true); ASSERT_EQ(my_runtime->GetAttr>("your_names").defined(), false); @@ -40,7 +40,7 @@ TEST(Runtime, Create) { } TEST(Runtime, UnknownAttr) { - Map attrs = {{"woofles", Bool(true)}}; + Map attrs = {{"woofles", runtime::Bool(true)}}; ASSERT_THROW(Runtime::Create("TestRuntime", attrs), Error); } @@ -64,7 +64,7 @@ TEST(RuntimeRegistry, ListRuntimeOptions) { Map attrs = Runtime::ListRuntimeOptions("TestRuntime"); ICHECK_EQ(attrs.empty(), false); - ICHECK_EQ(attrs["my_bool"], "IntImm"); + ICHECK_EQ(attrs["my_bool"], "runtime.BoxBool"); ICHECK_EQ(attrs["your_names"], "Array"); ICHECK_EQ(attrs["another_option"], "runtime.String"); } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 2db4b572bf60..0a2b8206d322 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -32,15 +32,15 @@ using namespace tvm; TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU) .set_attr("Attr1", "Value1") - .add_attr_option("my_bool") + .add_attr_option("my_bool") .add_attr_option>("your_names") - .add_attr_option>("her_maps"); + .add_attr_option>("her_maps"); TargetJSON TestTargetParser(TargetJSON target) { String mcpu = Downcast(target.at("mcpu")); target.Set("mcpu", String("super_") + mcpu); target.Set("keys", Array({"super"})); - target.Set("features", Map{{"test", Bool(true)}}); + target.Set("features", Map{{"test", runtime::Bool(true)}}); return target; } @@ -76,14 +76,14 @@ TEST(TargetKind, GetAttrMap) { TEST(TargetCreation, NestedConfig) { Map config = { - {"my_bool", Bool(true)}, + {"my_bool", runtime::Bool(true)}, {"your_names", Array{"junru", "jian"}}, {"kind", String("TestTargetKind")}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -91,13 +91,14 @@ TEST(TargetCreation, NestedConfig) { ICHECK_EQ(target->kind, TargetKind::Get("TestTargetKind").value()); ICHECK_EQ(target->tag, ""); ICHECK(target->keys.empty()); - Bool my_bool = target->GetAttr("my_bool").value(); + runtime::Bool my_bool = target->GetAttr("my_bool").value(); ICHECK_EQ(my_bool.operator bool(), true); Array your_names = target->GetAttr>("your_names").value(); ICHECK_EQ(your_names.size(), 2U); ICHECK_EQ(your_names[0], "junru"); ICHECK_EQ(your_names[1], "jian"); - Map her_maps = target->GetAttr>("her_maps").value(); + Map her_maps = + target->GetAttr>("her_maps").value(); ICHECK_EQ(her_maps.size(), 2U); ICHECK_EQ(her_maps["a"], 1); ICHECK_EQ(her_maps["b"], 2); @@ -105,15 +106,15 @@ TEST(TargetCreation, NestedConfig) { TEST(TargetCreationFail, UnrecognizedConfigOption) { Map config = { - {"my_bool", Bool(true)}, + {"my_bool", runtime::Bool(true)}, {"your_names", Array{"junru", "jian"}}, {"kind", String("TestTargetKind")}, {"bad", ObjectRef(nullptr)}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -133,9 +134,9 @@ TEST(TargetCreationFail, TypeMismatch) { {"kind", String("TestTargetKind")}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -150,13 +151,13 @@ TEST(TargetCreationFail, TypeMismatch) { TEST(TargetCreationFail, TargetKindNotFound) { Map config = { - {"my_bool", Bool("true")}, + {"my_bool", runtime::Bool("true")}, {"your_names", Array{"junru", "jian"}}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -178,15 +179,16 @@ TEST(TargetCreation, TargetParser) { TEST(TargetCreation, TargetFeatures) { Target test_target_with_parser("TestTargetParser -mcpu=woof"); - ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); + ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); Target test_target_no_parser("TestTargetKind"); - ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); - ASSERT_EQ(test_target_no_parser->GetFeature("test", Bool(true)).value(), true); + ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); + ASSERT_EQ(test_target_no_parser->GetFeature("test", runtime::Bool(true)).value(), + true); } TEST(TargetCreation, TargetFeaturesBeforeParser) { - Map features = {{"test", Bool(true)}}; + Map features = {{"test", runtime::Bool(true)}}; Map config = { {"kind", String("TestTargetParser")}, {"mcpu", String("woof")}, @@ -469,13 +471,13 @@ TEST(TargetCreation, DetectSystemTriple) { #endif TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_1", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU) .set_attr(tvm::attr::kRelayToTIR, diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py index bbfb8bd2db12..f5b1651e115a 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py +++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py @@ -15,10 +15,14 @@ # specific language governing permissions and limitations # under the License. """Test packed function FFI.""" +import gc + +import numpy as np + import tvm from tvm import te import tvm.testing -import numpy as np +from tvm.script import tir as T def test_get_global(): @@ -37,7 +41,7 @@ def my_packed_func(*args): def test_get_callback_with_node(): - x = tvm.runtime.convert(10) + x = T.int32(10) def test(y): assert y.handle != x.handle @@ -66,7 +70,7 @@ def add(x): myf = tvm.runtime.convert(addy) f = myf(10) - assert f(11).value == 21 + assert f(11) == 21 def test_convert(): @@ -113,6 +117,14 @@ def test_device_func(dev): def test_rvalue_ref(): def callback(x, expected_count): + # The use count of TVM objects is decremented as part of + # `ObjectRef.__del__`, which runs when the Python object is + # destructed. However, Python object destruction is not + # deterministic, and even CPython's reference-counting is + # considered an implementation detail. Therefore, to ensure + # correct results from this test, `gc.collect()` must be + # explicitly called. + gc.collect() assert expected_count == tvm.testing.object_use_count(x) return x diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index afd716cde389..42f5b0ccd0b8 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -16,16 +16,27 @@ # under the License. import tvm import tvm.testing -from tvm import te +from tvm import te, tir +from tvm.script import tir as T class CanonicalChecker: def __init__(self): self.analyzer = tvm.arith.Analyzer() + def _convert(self, expr): + # TODO(Lunderberg): Make utility functions `tir.convert` and + # `relax.convert` that convert to their respective IR types. + # Implementation should be in C++, and should only consist of + # conversions that are applied automatically through FFI. + if isinstance(expr, int): + return T.int32(expr) + else: + return expr + def verify(self, data, expected): res = self.analyzer.canonical_simplify(data) - expected = tvm.runtime.convert(expected) + expected = self._convert(expected) assert tvm.ir.structural_equal(res, expected), "\ndata={}\nres={}\nexpected={}".format( data, res, expected ) @@ -377,13 +388,13 @@ def test_simplify_normalize_min_value_expr(): x = te.var("x", "int32") ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32")) - ck.verify(te.min_value("int32") + x == 0, False) + ck.verify(te.min_value("int32") + x == 0, tir.const(False)) ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32")) - ck.verify(0 == te.min_value("int32") + x, False) + ck.verify(0 == te.min_value("int32") + x, tir.const(False)) ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32")) - ck.verify(x + te.min_value("int32") == 0, False) + ck.verify(x + te.min_value("int32") == 0, tir.const(False)) ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32")) - ck.verify(0 == x + te.min_value("int32"), False) + ck.verify(0 == x + te.min_value("int32"), tir.const(False)) def test_proddiv_simplify(): diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py index 3a10ec05efeb..f0e6f05adfad 100644 --- a/tests/python/arith/test_arith_iter_affine_map.py +++ b/tests/python/arith/test_arith_iter_affine_map.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm.tir import floordiv, floormod +from tvm.script import tir as T def ifuse(inputs, pred_extent=None): @@ -537,7 +538,7 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) tvm.ir.assert_structural_equal(res[0][1], x + c) tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18) - tvm.ir.assert_structural_equal(res[1][1], True) + tvm.ir.assert_structural_equal(res[1][1], T.bool(True)) # compound 1 i0 = create_iter("i0", 4) @@ -553,7 +554,7 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) @@ -569,7 +570,7 @@ def test_subspace_division(): assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices @@ -587,11 +588,11 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) - tvm.ir.assert_structural_equal(res[2][1], True) + tvm.ir.assert_structural_equal(res[2][1], T.bool(True)) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 @@ -606,9 +607,9 @@ def test_subspace_division(): assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) - tvm.ir.assert_structural_equal(res[2][0], True) + tvm.ir.assert_structural_equal(res[2][0], T.bool(True)) tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices @@ -642,10 +643,10 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices @@ -661,9 +662,9 @@ def test_subspace_division(): assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], j0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map( @@ -690,10 +691,10 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) @@ -735,8 +736,8 @@ def test_subspace_divide_trivial_iters(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], x) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], y) diff --git a/tests/python/arith/test_arith_narrow_predicate_expression.py b/tests/python/arith/test_arith_narrow_predicate_expression.py index d38fe70f6b5c..0aa353c60041 100644 --- a/tests/python/arith/test_arith_narrow_predicate_expression.py +++ b/tests/python/arith/test_arith_narrow_predicate_expression.py @@ -20,6 +20,7 @@ from tvm import tir from tvm.runtime import convert +from tvm.script import tir as T i = tir.Var("i", "int32") @@ -42,18 +43,18 @@ [i < n, i < 0], [i <= n, i <= 0], [i >= n, i >= 7], - [n > i, convert(0) > i], - [n < i, convert(7) < i], - [n <= i, convert(7) <= i], - [n >= i, convert(0) >= i], - [i == n, tir.all(i <= 0, convert(7) <= i)], - [n == i, tir.all(convert(7) <= i, i <= 0)], - [i != n, tir.any(i < 0, convert(7) < i)], - [n != i, tir.any(convert(7) < i, i < 0)], + [n > i, T.int32(0) > i], + [n < i, T.int32(7) < i], + [n <= i, T.int32(7) <= i], + [n >= i, T.int32(0) >= i], + [i == n, tir.all(i <= 0, T.int32(7) <= i)], + [n == i, tir.all(T.int32(7) <= i, i <= 0)], + [i != n, tir.any(i < 0, T.int32(7) < i)], + [n != i, tir.any(T.int32(7) < i, i < 0)], [i // 4 > n, i // 4 > 7], - [n < i // 4, convert(7) < i // 4], + [n < i // 4, T.int32(7) < i // 4], [(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0], - [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)], + [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, T.int32(0) <= tir.Add(i, 0) // 4)], [i + n < 10, i + 7 < 10], [i - n < 10, tir.Sub(i, 0) < 10], [tir.Not(i < n), tir.Not(i < 7)], diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 90f0aeef47d7..7fc1862192d6 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -27,6 +27,8 @@ from tvm.tir import truncdiv as tdiv from tvm.tir import truncmod as tmod +from tvm.script import tir as T + class TestCase: def __init__(self, before, expected, preconditions=None): @@ -35,10 +37,21 @@ def __init__(self, before, expected, preconditions=None): if isinstance(expected, tir.expr.EqualOp): expected = expected.asobject() - self.before = before - self.expected = expected + self.before = self._convert(before) + self.expected = self._convert(expected) self.preconditions = preconditions + @staticmethod + def _convert(expr): + if isinstance(expr, tir.expr.EqualOp): + return expr.asobject() + elif isinstance(expr, int): + return T.int32(expr) + elif isinstance(expr, float): + return T.float32(expr) + else: + return expr + @property def constraint(self): if self.preconditions is None: @@ -1008,8 +1021,8 @@ class TestComparisons(BaseCompare): TestCase(tir.all(fld(x, 8) == -3, flm(x, 8) == 4), x == -20), TestCase(tir.all(flm(x, 8) == 4, fld(x, 8) == -3), x == -20), # Rewrite based on definition of integer division - TestCase(tir.all(tvm.runtime.convert(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), - TestCase(tir.all(x - y * 5 < 5, tvm.runtime.convert(0) <= x - y * 5), y == fld(x, 5)), + TestCase(tir.all(T.int32(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), + TestCase(tir.all(x - y * 5 < 5, T.int32(0) <= x - y * 5), y == fld(x, 5)), # Narrow upper bound using floormod TestCase(tir.all(x < 20, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), TestCase(tir.all(x < 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), @@ -1025,36 +1038,36 @@ class TestComparisons(BaseCompare): # Merge a known floordiv and an upper bound of floormod into a value range TestCase( tir.all(fld(x, 10) == 5, flm(x, 10) < 7), - tir.all(tvm.runtime.convert(50) <= x, x < 57), + tir.all(T.int32(50) <= x, x < 57), ), TestCase( tir.all(fld(x, 10) == 5, flm(x, 10) <= 7), - tir.all(tvm.runtime.convert(50) <= x, x <= 57), + tir.all(T.int32(50) <= x, x <= 57), ), TestCase( tir.all(fld(x, 10) == -5, flm(x, 10) < 7), - tir.all(tvm.runtime.convert(-50) <= x, x < -43), + tir.all(T.int32(-50) <= x, x < -43), ), TestCase( tir.all(fld(x, 10) == -5, flm(x, 10) <= 7), - tir.all(tvm.runtime.convert(-50) <= x, x <= -43), + tir.all(T.int32(-50) <= x, x <= -43), ), # Merge a known floordiv and an lower bound of floormod into a value range TestCase( - tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) < flm(x, 10)), - tir.all(tvm.runtime.convert(57) < x, x < 60), + tir.all(fld(x, 10) == 5, T.int32(7) < flm(x, 10)), + tir.all(T.int32(57) < x, x < 60), ), TestCase( - tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) <= flm(x, 10)), - tir.all(tvm.runtime.convert(57) <= x, x < 60), + tir.all(fld(x, 10) == 5, T.int32(7) <= flm(x, 10)), + tir.all(T.int32(57) <= x, x < 60), ), TestCase( - tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) < flm(x, 10)), - tir.all(tvm.runtime.convert(-43) < x, x < -40), + tir.all(fld(x, 10) == -5, T.int32(7) < flm(x, 10)), + tir.all(T.int32(-43) < x, x < -40), ), TestCase( - tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) <= flm(x, 10)), - tir.all(tvm.runtime.convert(-43) <= x, x < -40), + tir.all(fld(x, 10) == -5, T.int32(7) <= flm(x, 10)), + tir.all(T.int32(-43) <= x, x < -40), ), TestCase(tvm.te.min(x, 11) < 10, x < 10), TestCase(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool")), @@ -1224,14 +1237,16 @@ class TestIfThenElse(BaseCompare): class TestCLZ(BaseCompare): test_case = tvm.testing.parameter( - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), 32), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), 31), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), 30), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), 24), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), 64), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), 63), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), 62), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), 56), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), T.int32(32)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), T.int32(31)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), T.int32(30)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), T.int32(24)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), T.int32(64)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), T.int32(63)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), T.int32(62)), + TestCase( + tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), T.int32(56) + ), ) diff --git a/tests/python/arith/test_arith_solve_linear_equations.py b/tests/python/arith/test_arith_solve_linear_equations.py index 24eb860c55f6..3195a4ae514f 100644 --- a/tests/python/arith/test_arith_solve_linear_equations.py +++ b/tests/python/arith/test_arith_solve_linear_equations.py @@ -19,6 +19,7 @@ import pytest import tvm from tvm import te, arith, ir, tir, testing +from tvm.script import tir as T def test_solution_consistency(): @@ -109,8 +110,8 @@ def test_unique_solution(): [x, y], ) assert list(solution.dst.variables) == [] - assert ir.structural_equal(solution.src_to_dst[x], 15) - assert ir.structural_equal(solution.src_to_dst[y], 5) + assert ir.structural_equal(solution.src_to_dst[x], T.int32(15)) + assert ir.structural_equal(solution.src_to_dst[y], T.int32(5)) def test_low_rank(): @@ -128,7 +129,7 @@ def test_low_rank(): [n0] = solution.dst.variables assert ir.structural_equal(solution.src_to_dst[x], n0 + 10) assert ir.structural_equal(solution.src_to_dst[y], -n0) - assert ir.structural_equal(solution.src_to_dst[z], 5) + assert ir.structural_equal(solution.src_to_dst[z], T.int32(5)) def test_infer_range(): @@ -149,12 +150,12 @@ def test_infer_range(): assert ir.structural_equal(solution.src_to_dst[x], n0) assert ir.structural_equal(solution.src_to_dst[y], -n0) # inferred from y's range - assert ir.structural_equal(solution.dst.ranges[n0].min, -9) - assert ir.structural_equal(solution.dst.ranges[n0].extent, 10) + assert ir.structural_equal(solution.dst.ranges[n0].min, T.int32(-9)) + assert ir.structural_equal(solution.dst.ranges[n0].extent, T.int32(10)) # additional inequality is added into the system for x [ineq] = solution.dst.relations assert isinstance(ineq, tvm.tir.LE) - assert ir.structural_equal(ineq.a, -5) + assert ir.structural_equal(ineq.a, T.int32(-5)) assert ir.structural_equal(ineq.b, n0) @@ -172,7 +173,7 @@ def test_ill_formed(): ) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - assert ir.structural_equal(rel, False) + ir.assert_structural_equal(rel, tir.const(False)) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/arith/test_arith_solve_linear_inequality.py b/tests/python/arith/test_arith_solve_linear_inequality.py index 5285da12e75d..664258ae7cf1 100644 --- a/tests/python/arith/test_arith_solve_linear_inequality.py +++ b/tests/python/arith/test_arith_solve_linear_inequality.py @@ -19,6 +19,7 @@ import pytest import tvm from tvm import te, arith, ir, tir, testing +from tvm.script import tir as T @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11458") @@ -113,10 +114,10 @@ def test_dual_variable(): [x_new, y_new] = solution.dst.variables [rel] = solution.dst.relations assert ir.structural_equal(rel, (y_new * 2) + x_new <= 10) - assert ir.structural_equal(solution.dst.ranges[x_new].min, 0) - assert ir.structural_equal(solution.dst.ranges[x_new].extent, 11) - assert ir.structural_equal(solution.dst.ranges[y_new].min, 0) - assert ir.structural_equal(solution.dst.ranges[y_new].extent, 6) + assert ir.structural_equal(solution.dst.ranges[x_new].min, T.int32(0)) + assert ir.structural_equal(solution.dst.ranges[x_new].extent, T.int32(11)) + assert ir.structural_equal(solution.dst.ranges[y_new].min, T.int32(0)) + assert ir.structural_equal(solution.dst.ranges[y_new].extent, T.int32(6)) assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10)) assert ir.structural_equal(solution.src_to_dst[y], y_new) assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) @@ -185,7 +186,7 @@ def test_no_solution(): solution = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - assert ir.structural_equal(rel, False) + ir.assert_structural_equal(rel, tir.const(False)) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 112c521d06d4..112d1151febd 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -769,7 +769,7 @@ def check_cuda(dtype, n, l, padding, lanes): (n // lanes, l + 2 * padding, lanes), lambda i, j, k: tvm.te.if_then_else( tvm.te.any(j < padding, j >= l + padding), - tvm.runtime.convert(0).astype(dtype), + tvm.tir.const(0, dtype), A[i * lanes + k, j - padding], ), name="B", diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f50d63878e4f..d9a6fd6e62d1 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1138,5 +1138,46 @@ def func(): tvm.build(func) +def test_int_parameter(): + """Boolean may be passed to functions accepting int""" + + @T.prim_func + def func(arg: T.int32) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + if arg > 0: + return 10 + else: + return 20 + + built = tvm.build(func) + output = built(True) + assert output == 10 + + output = built(False) + assert output == 20 + + +def test_bool_parameter(): + """Integers may be passed to functions accepting bool""" + + @T.prim_func + def func(arg: T.bool) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + if arg: + return 10 + else: + return 20 + + built = tvm.build(func) + output = built(1) + assert output == 10 + + output = built(2) + assert output == 10 + + output = built(0) + assert output == 20 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/ir/test_container_structural_equal.py b/tests/python/ir/test_container_structural_equal.py index 61511c609ca4..238a77b4ef4b 100644 --- a/tests/python/ir/test_container_structural_equal.py +++ b/tests/python/ir/test_container_structural_equal.py @@ -56,20 +56,20 @@ def get_first_mismatch_ensure_symmetry(a, b): ( [1, 2, 3], [1, 4, 3], - ObjectPath.root().array_index(1).attr("value"), - ObjectPath.root().array_index(1).attr("value"), + ObjectPath.root().array_index(1), + ObjectPath.root().array_index(1), ), ( [1, 2, 3], [10, 2, 30], - ObjectPath.root().array_index(0).attr("value"), - ObjectPath.root().array_index(0).attr("value"), + ObjectPath.root().array_index(0), + ObjectPath.root().array_index(0), ), ( [1, 3, 4], [1, 2, 3, 4], - ObjectPath.root().array_index(1).attr("value"), - ObjectPath.root().array_index(1).attr("value"), + ObjectPath.root().array_index(1), + ObjectPath.root().array_index(1), ), ( [1, 2, 3], @@ -121,14 +121,28 @@ def test_shape_tuple_structural_equal_to_self(contents): assert get_first_mismatch_ensure_symmetry(a, b) is None +@pytest.mark.parametrize( + "contents", + [ + {}, + {"a": 1, "b": 2}, + {"a": True, "b": False}, + ], +) +def test_string_map_structural_equal_to_self(contents): + a = tvm.runtime.convert({**contents}) + b = tvm.runtime.convert({**contents}) + assert get_first_mismatch_ensure_symmetry(a, b) is None + + @pytest.mark.parametrize( "a, b, expected_a_path, expected_b_path", [ ( dict(a=3, b=4), dict(a=3, b=5), - ObjectPath.root().map_value("b").attr("value"), - ObjectPath.root().map_value("b").attr("value"), + ObjectPath.root().map_value("b"), + ObjectPath.root().map_value("b"), ), ( dict(a=3, b=4), diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index aa482dd65cd7..1e3249197851 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -23,16 +23,19 @@ def test_array(): a = tvm.runtime.convert([1, 2, 3]) assert len(a) == 3 - assert a[-1].value == 3 + assert a[-1] == 3 a_slice = a[-3:-1] - assert (a_slice[0].value, a_slice[1].value) == (1, 2) + assert (a_slice[0], a_slice[1]) == (1, 2) def test_array_save_load_json(): - a = tvm.runtime.convert([1, 2, 3]) + a = tvm.runtime.convert([1, 2, 3.5, True]) json_str = tvm.ir.save_json(a) a_loaded = tvm.ir.load_json(json_str) - assert a_loaded[1].value == 2 + assert a_loaded[1] == 2 + assert a_loaded[2] == 3.5 + assert a_loaded[3] == True + assert isinstance(a_loaded[3], bool) def test_dir_array(): @@ -66,7 +69,7 @@ def test_str_map(): assert "a" in amap assert len(amap) == 2 dd = dict(amap.items()) - assert amap["a"].value == 2 + assert amap["a"] == 2 assert "a" in dd assert "b" in dd @@ -78,7 +81,7 @@ def test_map_save_load_json(): json_str = tvm.ir.save_json(amap) amap = tvm.ir.load_json(json_str) assert len(amap) == 2 - dd = {kv[0].name: kv[1].value for kv in amap.items()} + dd = {kv[0].name: kv[1] for kv in amap.items()} assert dd == {"a": 2, "b": 3} diff --git a/tests/python/ir/test_ir_type.py b/tests/python/ir/test_ir_type.py index 2355aa19adec..b70406c1bb7a 100644 --- a/tests/python/ir/test_ir_type.py +++ b/tests/python/ir/test_ir_type.py @@ -16,6 +16,7 @@ # under the License. """Test type nodes in the IR""" import tvm +from tvm.script import tir as T def check_json_roundtrip(node): @@ -38,11 +39,9 @@ def test_tensor_type_bad_constructor(): def test_tensor_type(): - shape = tvm.runtime.convert([1, 2, 3]) - dtype = "float32" - tt = tvm.ir.TensorType(shape, dtype) - assert tt.dtype == dtype - assert tt.shape == shape + tt = tvm.ir.TensorType([1, 2, 3], "float32") + assert tt.dtype == "float32" + assert list(tt.shape) == [T.int32(1), T.int32(2), T.int32(3)] assert tt.span == None str(tt) check_json_roundtrip(tt) diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index f1709c449d16..b0ddbe93601e 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -40,7 +40,7 @@ def test_constant(): ) assert ( constant.__str__() - == """R.dist.const(1, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" + == """R.dist.const(1.0, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" ) @@ -144,7 +144,7 @@ def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer(( vi, vj = T.axis.remap("SS", [i, j]) T.reads(x[vi, vj]) T.writes(y[vi, vj]) - y[vi, vj] = x[vi, vj] + T.float32(1) + y[vi, vj] = x[vi, vj] + T.float32(1.0) @R.function def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) -> R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"): diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 97ad9f5dd034..64d5c7381171 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -404,7 +404,7 @@ def f( "op": 'ExternFunc(global_symbol="contrib.tensor_array_stack")', "args": '[Var(name_hint="x"), Var(name_hint="y")]', "sinfo_args": "[ObjectStructInfo()]", - "attrs": '{"test_attr": 1}', + "attrs": '{"test_attr": True}', }, extern_call_text, ) diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py index 2ab5afaabf24..1efbd690f034 100644 --- a/tests/python/relax/test_backend_dispatch_sort_scan.py +++ b/tests/python/relax/test_backend_dispatch_sort_scan.py @@ -63,6 +63,13 @@ def foo(x: R.Tensor((2, 3), "float32", "llvm")): def test_dispatch_scanop_cuda(): + """R.cumsum and R.cumprod may be lowered with TOPI for GPU + + For the purpose of testing, this test case intentionally uses the + `exclusive=True` argument to prevent the `R.cumsum` from being + lowered to the packed func `"gpu_2d_continuous_cumsum"`. + """ + @I.ir_module class Before: I.module_global_infos({"vdevice": [I.vdevice("cuda", 0)]}) @@ -70,7 +77,7 @@ class Before: @R.function def main(x: R.Tensor(("m", 3), "float32", "cuda")): with R.dataflow(): - lv0 = R.cumsum(x, axis=1) + lv0 = R.cumsum(x, axis=1, exclusive=True) lv1 = R.cumprod(lv0, axis=1) gv = lv1 R.output(gv) @@ -89,6 +96,7 @@ def main(x: R.Tensor(("m", 3), "float32", "cuda")): topi.cuda.cumsum, x, axis=1, + exclusive=True, ) out = bb.emit_te( topi.cuda.cumprod, diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 7b64eb1dee39..e93547d83e3c 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -395,7 +395,7 @@ def test_call_tir_with_grad(): """ v0: R.Tensor((54, 96), dtype="float32") x = T.int64() -R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": T.float32(1), "x": x}) +R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": 1.0, "x": x}) """, ) @@ -758,7 +758,7 @@ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": True}) R.print(format=R.str("Hi there!")) z: R.Tensor((), dtype="int32") = R.add(x, x) return z @@ -770,7 +770,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function(private=True) def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": True}) R.print(format=R.str("Lol")) z: R.Tensor((), dtype="int32") = R.multiply(x, x) return z diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index ab40e181a35a..30fd06d4f14d 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -566,7 +566,7 @@ def main(shape: R.Prim(value="n")): assert func(2) == 4 - with pytest.raises(tvm.TVMError): + with pytest.raises(TypeError): func(ShapeTuple([2])) diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 9a4817f5fd8a..60f096585dfe 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -118,9 +118,10 @@ class Expected: @T.prim_func def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): T.func_attr({"global_symbol": "__vmtir__ife"}) - if T.cast( - T.tvm_call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))), + if T.Call( "bool", + tvm.ir.Op.get("tir.tvm_call_packed"), + ["vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))], ): T.anylist_setitem_call_packed( r, diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 4031790fc383..b79713e05ed3 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -18,6 +18,7 @@ import numpy as np import tvm +from tvm.script import tir as T from tvm import relay from tvm.relay.build_module import bind_params_by_name from tvm.relay.dataflow_pattern import * @@ -115,7 +116,7 @@ def test_DataTypePattern(): def test_ShapePattern(): - shape = [10, 10] + shape = [T.int32(10), T.int32(10)] pattern = has_shape(shape) assert isinstance(pattern, ShapePattern) tvm.ir.assert_structural_equal(pattern.shape, shape) diff --git a/tests/python/relay/test_executor.py b/tests/python/relay/test_executor.py index d703ef1f3d9a..04662f21ae9e 100644 --- a/tests/python/relay/test_executor.py +++ b/tests/python/relay/test_executor.py @@ -57,7 +57,7 @@ def test_create_executor_attr_type_incorrect(): with pytest.raises( TVMError, match='Attribute "interface-api" should have type "runtime.String"' - ' but instead found "IntImm"', + ' but instead found "runtime.BoxBool"', ): Executor("aot", {"interface-api": True}) diff --git a/tests/python/relay/test_runtime.py b/tests/python/relay/test_runtime.py index ea15dd0d3c88..db8252f3a3c4 100644 --- a/tests/python/relay/test_runtime.py +++ b/tests/python/relay/test_runtime.py @@ -51,7 +51,7 @@ def test_create_runtime_attr_not_found(): def test_create_runtime_attr_type_incorrect(): with pytest.raises( TVMError, - match='Attribute "system-lib" should have type "IntImm"' + match='Attribute "system-lib" should have type "runtime.BoxBool"' ' but instead found "runtime.String"', ): Runtime("crt", {"system-lib": "woof"}) @@ -65,7 +65,7 @@ def test_list_runtimes(): def test_list_runtime_options(runtime): aot_options = Runtime.list_registered_options(runtime) assert "system-lib" in aot_options - assert aot_options["system-lib"] == "IntImm" + assert aot_options["system-lib"] == "runtime.BoxBool" def test_list_runtime_options_not_found(): diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index f18994d52ce9..7d0cd51d3298 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -18,12 +18,13 @@ for expressions. """ import pytest +import numpy as np + import tvm -from tvm import IRModule, parser, relay, te -from tvm.relay import analysis, op, transform +from tvm import IRModule, relay +from tvm.relay import op, transform from tvm.relay.op import op as _op - -import numpy as np +from tvm.script import tir as T def infer_mod(mod, annotate_spans=True): @@ -554,40 +555,32 @@ def test_repeat_register(): assert "Operator custom_log3 is registered before" in str(cm.execption) -def test_argreduce_infer_return_type(): +@pytest.mark.parametrize("relay_op", [relay.op.argmax, relay.op.argmin]) +@pytest.mark.parametrize( + "shape_dtype", + [ + ("int32", T.int32), + ("int64", T.int64), + ], + ids=["int32", "int64"], +) +def test_argreduce_infer_return_type(relay_op, shape_dtype): x_shape = (1, 1) broadcast_shape = [1, 1] - shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))] - - # Testing with argmax - for (sdtype, conv) in shape_dtypes: - x = relay.var("data", relay.TensorType(x_shape, "float32")) - broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) - argmax = relay.op.argmax(broadcast_to, axis=[1]) - - f = relay.Function([x], argmax) - assert_has_type( - f, - relay.FuncType( - [relay.TensorType(broadcast_shape, "float32")], - relay.TensorType([conv(1)], dtype=sdtype), - ), - ) - - # Testing with argmin - for (sdtype, conv) in shape_dtypes: - x = relay.var("data", relay.TensorType(x_shape, "float32")) - broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) - argmin = relay.op.argmin(broadcast_to, axis=[1]) - - f = relay.Function([x], argmin) - assert_has_type( - f, - relay.FuncType( - [relay.TensorType(broadcast_shape, "float32")], - relay.TensorType([conv(1)], dtype=sdtype), - ), - ) + (sdtype, conv) = shape_dtype + + x = relay.var("data", relay.TensorType(x_shape, "float32")) + broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) + argmax = relay_op(broadcast_to, axis=[1]) + + f = relay.Function([x], argmax) + assert_has_type( + f, + relay.FuncType( + [relay.TensorType(broadcast_shape, "float32")], + relay.TensorType([conv(1)], dtype=sdtype), + ), + ) if __name__ == "__main__": diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py index 7538075ae7f8..e0d216b33e9a 100644 --- a/tests/python/runtime/test_runtime_container.py +++ b/tests/python/runtime/test_runtime_container.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. -import numpy as np +import pickle import random + +import numpy as np + import tvm import tvm.testing -import pickle -from tvm import te from tvm import nd, relay from tvm.runtime import container as _container @@ -96,8 +97,123 @@ def test_shape_tuple(): assert stuple == z +def test_bool_argument(): + """Boolean objects are currently stored as int""" + func = tvm.get_global_func("testing.AcceptsBool") + + assert isinstance(func(True), bool) + assert isinstance(func(1), bool) + assert isinstance(func(0), bool) + + +def test_int_argument(): + func = tvm.get_global_func("testing.AcceptsInt") + + assert isinstance(func(True), int) + assert isinstance(func(1), int) + assert isinstance(func(0), int) + + +def test_object_ref_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRef") + + assert isinstance(func(True), bool) + assert isinstance(func(1), int) + assert isinstance(func(3.5), float) + assert func(3.5) == 3.5 + + +def test_object_ref_array_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRefArray") + + assert isinstance(func([True, 17, "hello"]), bool) + assert isinstance(func([True]), bool) + assert isinstance(func([17]), int) + assert isinstance(func(["hello"]), str) + + +def test_map_argument_returns_value(): + func = tvm.get_global_func("testing.AcceptsMapReturnsValue") + + res = func({"a": 1, "b": 2}, "a") + assert isinstance(res, int) + assert res == 1 + + res = func({"a": True, "b": False}, "a") + assert isinstance(res, bool) + assert res == True + + +def test_map_argument_returns_map(): + func = tvm.get_global_func("testing.AcceptsMapReturnsMap") + + res = func({"a": 1, "b": 2}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, int) + + res = func({"a": False, "b": True}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, bool) + + +def test_conversion_of_arg(): + """Arguments may be converted + + The calling side of the FFI converts to types that are available + at runtime. However, there may be additional type conversions + required, that must be performed on the callee-side of the FFI. + """ + + func = tvm.get_global_func("testing.AcceptsPrimExpr") + + res = func(1) + assert isinstance(res, tvm.tir.IntImm) + assert res.dtype == "int32" + + res = func(True) + assert isinstance(res, tvm.tir.IntImm) + assert res.dtype == "bool" + + +def test_conversion_of_array_elements(): + """Elements of an array may require conversion from FFI to param type + + Like `test_conversion_of_arg`, but conversions must be applied + recursively to array elements. Here, the Python-side of the FFI + converts the array `[1,2]` to `Array{runtime::Int(1), + runtime::Int(2)}`, and the C++ side of the FFI converts to + `Array{IntImm(1), IntImm(2)}`. + """ + + func = tvm.get_global_func("testing.AcceptsArrayOfPrimExpr") + + res = func([1, False]) + assert isinstance(res[0], tvm.tir.IntImm) + assert res[0].dtype == "int32" + assert isinstance(res[1], tvm.tir.IntImm) + assert res[1].dtype == "bool" + + +def test_conversion_of_map_values(): + """Elements of a map may require conversion from FFI to param type + + Like `test_conversion_of_arg`, but conversions must be applied + recursively to map elements. Here, the Python-side of the FFI + converts the map `{'a':1, 'b':2}` to `Map{{"a", runtime::Int(1)}, + {"b", runtime::Int(2)}}`, and the C++ side of the FFI converts to + `Map{{"a", IntImm(1)}, {"b", IntImm(2)}}`. + """ + + func = tvm.get_global_func("testing.AcceptsMapOfPrimExpr") + + res = func({"a": 1, "b": False}) + assert isinstance(res["a"], tvm.tir.IntImm) + assert res["a"].dtype == "int32" + assert isinstance(res["b"], tvm.tir.IntImm) + assert res["b"].dtype == "bool" + + if __name__ == "__main__": - test_string() - test_adt_constructor() - test_tuple_object() - test_shape_tuple() + tvm.testing.main() diff --git a/tests/python/te/test_te_schedule_tensorize.py b/tests/python/te/test_te_schedule_tensorize.py index 79aecb78902a..419d3edb5c3d 100644 --- a/tests/python/te/test_te_schedule_tensorize.py +++ b/tests/python/te/test_te_schedule_tensorize.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T def intrin_vadd(xo, m, n): @@ -100,6 +101,7 @@ def add(m): def check(m, factor): x, y, z = add(m) + factor = T.int32(factor) s = te.create_schedule(z.op) xo, xi = s[z].split(z.op.axis[0], factor=factor) vadd = intrin_vadd(xo, m, factor) @@ -133,7 +135,7 @@ def check_cache_write(m, factor): finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[z_global], dom_map) # outer loop var will be rebased, so min value is the new loop var and extent is 1 - tvm.ir.assert_structural_equal(out_dom[xo].extent, 1) + tvm.ir.assert_structural_equal(out_dom[xo].extent, T.int32(1)) assert isinstance(out_dom[xo].min, tvm.tir.Var) assert xo.var.name == out_dom[xo].min.name @@ -183,7 +185,7 @@ def check(factor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -207,7 +209,7 @@ def check_rfactor(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -230,7 +232,7 @@ def check_rfactor_no_reset(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -254,7 +256,7 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -264,10 +266,10 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) - check(16) - check_rfactor(16, 16) - check_rfactor_no_reset(16, 16) - check_rfactor_no_reset_multi_reduction(16, 16) + check(T.int32(16)) + check_rfactor(T.int32(16), T.int32(16)) + check_rfactor_no_reset(T.int32(16), T.int32(16)) + check_rfactor_no_reset_multi_reduction(T.int32(16), T.int32(16)) # This tests whether algorithm and intrinsics expressions are simplified diff --git a/tests/python/te/test_te_tag.py b/tests/python/te/test_te_tag.py index 6e88a12614cf..a4b76e7d6736 100644 --- a/tests/python/te/test_te_tag.py +++ b/tests/python/te/test_te_tag.py @@ -57,12 +57,12 @@ def test_with(): assert C.op.tag == "gemm" assert "hello" in C.op.attrs assert "xx" not in C.op.attrs - assert C.op.attrs["hello"].value == 1 + assert C.op.attrs["hello"] == 1 CC = tvm.ir.load_json(tvm.ir.save_json(C)) - assert CC.op.attrs["hello"].value == 1 - assert CC.op.attrs["arr"][0].value == 10 - # str format happened to be json compatible - assert json.loads(str(CC.op.attrs))["arr"][1] == 12 + assert CC.op.attrs["hello"] == 1 + assert len(CC.op.attrs["arr"]) == 2 + assert CC.op.attrs["arr"][0] == 10 + assert CC.op.attrs["arr"][1] == 12 def test_decorator(): diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py index e94a4f09ec56..0e610cc1659b 100644 --- a/tests/python/tir-base/test_lower_build.py +++ b/tests/python/tir-base/test_lower_build.py @@ -122,7 +122,7 @@ def test_lower_build_tir_func(): def test_lower_build_tir_module(): func = matmul.with_attr("global_symbol", "main") - func = func.with_attr("tir.noalias", True) + func = func.with_attr("tir.noalias", T.bool(True)) ir_mod = IRModule({"main": func}) # check lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index b4b773197b14..d706e65d8186 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest + import tvm import tvm.testing from tvm import te from tvm.tir import Buffer +from tvm.script import tir as T + import numpy as np +import pytest def test_buffer(): @@ -78,9 +81,9 @@ def test_buffer_access_ptr_extent(): # Test extent from input params aptr = Ab.access_ptr("rw", extent=200) - tvm.ir.assert_structural_equal(aptr.args[3], 200) + tvm.ir.assert_structural_equal(aptr.args[3], T.int32(200)) aptr = Ab.access_ptr("rw", offset=100, extent=100) - tvm.ir.assert_structural_equal(aptr.args[3], 100) + tvm.ir.assert_structural_equal(aptr.args[3], T.int32(100)) def test_buffer_vload(): @@ -88,7 +91,7 @@ def test_buffer_vload(): n = te.size_var("n") Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) - tvm.ir.assert_structural_equal(load.indices, [2, 3]) + tvm.ir.assert_structural_equal(load.indices, [T.int32(2), T.int32(3)]) def test_buffer_offset_of(): @@ -259,7 +262,7 @@ def test_buffer_flatten(): buf = tvm.tir.decl_buffer([16, 32]) flat = buf.get_flattened_buffer() assert buf.data.same_as(flat.data) - tvm.ir.assert_structural_equal(flat.shape, [16 * 32]) + tvm.ir.assert_structural_equal(flat.shape, [T.int32(16 * 32)]) def test_buffer_flatten_preserves_identity(): @@ -273,8 +276,8 @@ def test_buffer_flatten_uses_axis_separators(): """Flattening to N-d physical buffers uses the axis separators""" buf = tvm.tir.decl_buffer([4, 16, 32], axis_separators=[2]) flat = buf.get_flattened_buffer() - tvm.ir.assert_structural_equal(flat.axis_separators, [1]) - tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32]) + tvm.ir.assert_structural_equal(flat.axis_separators, [T.int32(1)]) + tvm.ir.assert_structural_equal(flat.shape, [T.int32(4 * 16), T.int32(32)]) def test_invalid_axis_separators_raises_exception(): diff --git a/tests/python/tir-base/test_tir_index_map.py b/tests/python/tir-base/test_tir_index_map.py index e893ed897d65..3ddbd2f69f59 100644 --- a/tests/python/tir-base/test_tir_index_map.py +++ b/tests/python/tir-base/test_tir_index_map.py @@ -22,6 +22,7 @@ from tvm.ir import assert_structural_equal from tvm.runtime import const from tvm.tir import IndexMap, IntImm, floordiv, floormod +from tvm.script import tir as T def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: @@ -37,28 +38,22 @@ def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: def test_index_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") - assert_structural_equal(index_map.map_indices([0]), [0, 0]) - assert_structural_equal(index_map.map_indices([3]), [0, 3]) - assert_structural_equal(index_map.map_indices([4]), [1, 0]) - assert_structural_equal(index_map.map_indices([42]), [10, 2]) - assert_structural_equal( - index_map.map_indices([const(42, "int64")]), [const(10, "int64"), const(2, "int64")] - ) + assert_structural_equal(index_map.map_indices([0]), [T.int32(0), T.int32(0)]) + assert_structural_equal(index_map.map_indices([3]), [T.int32(0), T.int32(3)]) + assert_structural_equal(index_map.map_indices([4]), [T.int32(1), T.int32(0)]) + assert_structural_equal(index_map.map_indices([42]), [T.int32(10), T.int32(2)]) + assert_structural_equal(index_map.map_indices([T.int64(42)]), [T.int64(10), T.int64(2)]) def test_shape_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") - assert_structural_equal(index_map.map_shape([4]), [1, 4]) - assert_structural_equal(index_map.map_shape([16]), [4, 4]) + assert_structural_equal(index_map.map_shape([4]), [T.int32(1), T.int32(4)]) + assert_structural_equal(index_map.map_shape([16]), [T.int32(4), T.int32(4)]) - assert_structural_equal(index_map.map_shape([14]), [4, 4]) - assert_structural_equal( - index_map.map_shape([const(16, "int64")]), [const(4, "int64"), const(4, "int64")] - ) - assert_structural_equal( - index_map.map_shape([const(14, "int64")]), [const(4, "int64"), const(4, "int64")] - ) + assert_structural_equal(index_map.map_shape([14]), [T.int32(4), T.int32(4)]) + assert_structural_equal(index_map.map_shape([T.int64(16)]), [T.int64(4), T.int64(4)]) + assert_structural_equal(index_map.map_shape([T.int64(14)]), [T.int64(4), T.int64(4)]) def test_inverse(): @@ -82,28 +77,28 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[16], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.runtime.convert(False), ), "right_padding": dict( forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[15], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), ), "left_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[15], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.And(i == 0, j < 1), ), "left_and_right_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[14], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.Or( tvm.tir.And(i == 0, j < 1), tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), @@ -113,7 +108,7 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[dynamic_N], - post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, 4], + post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, T.int32(4)], padding=lambda i, j: tvm.tir.And( dynamic_N % (-4) != 0, tvm.tir.And(i == dynamic_N // 4, j >= dynamic_N % 4), @@ -127,10 +122,10 @@ def test_nonbijective_inverse_gives_error(): ], pre_shape=[14, 31], post_shape=[ - 4, # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 - 5, # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 - 4, # Range of iter%4 - 8, # Range of iter%8 + T.int32(4), # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 + T.int32(5), # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 + T.int32(4), # Range of iter%4 + T.int32(8), # Range of iter%8 ], padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or( tvm.tir.Or( @@ -147,35 +142,35 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 32, (i // 4) % 8, i % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k], pre_shape=[116], - post_shape=[4, 8, 4], + post_shape=[T.int32(4), T.int32(8), T.int32(4)], padding=lambda i, j, k: tvm.tir.And(i == 3, 4 * j + k >= 20), ), "multiple_right_padding_transpose": dict( forward=lambda i: [(i // 4) % 8, i // 32, i % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k], pre_shape=[116], - post_shape=[8, 4, 4], + post_shape=[T.int32(8), T.int32(4), T.int32(4)], padding=lambda j, i, k: tvm.tir.And(i == 3, 4 * j + k >= 20), ), "multiple_left_padding": dict( forward=lambda i: [(i + 5) // 32, ((i + 5) // 4) % 8, (i + 5) % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k - 5], pre_shape=[123], - post_shape=[4, 8, 4], + post_shape=[T.int32(4), T.int32(8), T.int32(4)], padding=lambda i, j, k: tvm.tir.And(i == 0, j * 4 + k < 5), ), "multiple_left_padding_with_transpose": dict( forward=lambda i: [((i + 5) // 4) % 8, (i + 5) // 32, (i + 5) % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k - 5], pre_shape=[123], - post_shape=[8, 4, 4], + post_shape=[T.int32(8), T.int32(4), T.int32(4)], padding=lambda j, i, k: tvm.tir.And(i == 0, j * 4 + k < 5), ), "outer_loop_extent_one": dict( forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [i * 4 + j], pre_shape=[3], - post_shape=[1, 4], + post_shape=[T.int32(1), T.int32(4)], padding=lambda i, j: tvm.runtime.convert(3) == j, ), } diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index eeedae1f127c..29efd95280be 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -32,7 +32,7 @@ def test_te_const(): assert isinstance(x, tvm.tir.IntImm) -def test_scalar_dtype_inference(): +def test_tir_const_dtype_inference(): for data in [ True, bool(1), @@ -49,28 +49,11 @@ def test_scalar_dtype_inference(): np.float64(1), ]: assert tvm.tir.const(data).dtype == str(np.array(data).dtype) + + assert tvm.tir.const(True).dtype == "bool" assert tvm.tir.const(1).dtype == "int32" assert tvm.tir.const(1.0).dtype == "float32" - for data in [ - True, - bool(1), - np.uint8(1), - np.uint16(1), - np.uint32(1), - np.uint64(1), - np.int8(1), - np.int16(1), - np.int32(1), - np.int64(1), - np.float16(1), - np.float32(1), - np.float64(1), - ]: - assert tvm.runtime.convert(data).dtype == str(np.array(data).dtype) - assert tvm.runtime.convert(1).dtype == "int32" - assert tvm.runtime.convert(1.0).dtype == "float32" - def test_make(): x = tvm.tir.const(1, "int32") @@ -133,7 +116,7 @@ def test_attr(): assert stmt.node == y a = tvm.runtime.convert(1) - assert a.value == 1 + assert a == 1 try: a.no_field assert False @@ -350,7 +333,7 @@ def test_prim_func(): assert len(func.buffer_map) == 1 f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True}) - assert f2.attrs["calling_conv"].value == 1 + assert f2.attrs["calling_conv"] == 1 assert not func.attrs diff --git a/tests/python/tir-schedule/test_tir_schedule_sampling.py b/tests/python/tir-schedule/test_tir_schedule_sampling.py index c2f3f89e6e12..8ae576e9b922 100644 --- a/tests/python/tir-schedule/test_tir_schedule_sampling.py +++ b/tests/python/tir-schedule/test_tir_schedule_sampling.py @@ -146,7 +146,7 @@ def test_sample_categorical_serialize(): decisions.append(rv) new_sch = verify_trace_roundtrip(sch, mod=elementwise) for i, new_inst in enumerate(new_sch.trace.insts): - assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] + assert decisions[i] == candidates[new_sch.trace.decisions[new_inst]] def test_sample_perfect_tile_power_of_two(): diff --git a/tests/python/tir-schedule/test_tir_schedule_state.py b/tests/python/tir-schedule/test_tir_schedule_state.py index 74880e5a42d9..c023b9dbc59d 100644 --- a/tests/python/tir-schedule/test_tir_schedule_state.py +++ b/tests/python/tir-schedule/test_tir_schedule_state.py @@ -155,10 +155,10 @@ def test_replace_direct_write0(): old_hash = s.mod["main"].__hash__() sref = s.get_sref(s.mod["main"].body.block.body[1]) s.replace(sref, target) - # There is no other reference so the AST node can be written directly - assert old_hash == s.mod["main"].__hash__() # Check the replaced part is equal to the target tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target) + # There is no other reference so the AST node can be written directly + assert old_hash == s.mod["main"].__hash__() # The target reuse the stmt of the sref, so the sref won't be None assert sref.stmt is not None diff --git a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py index d5d5e0634ef6..cb7151f875e3 100644 --- a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py +++ b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py @@ -1029,38 +1029,45 @@ class TestTileAwareCompaction(BaseCompactTest): # it is not an opaque block case intentionally is_lower_order_free = False - @T.prim_func - def before( - A: T.Buffer((128, 128), "float32"), - B: T.Buffer((128, 128), "float32"), - C: T.Buffer((128, 128), "float32"), - ): - for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): - for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): - A_local = T.decl_buffer((26, 128), scope="local") - B_local = T.decl_buffer((128, 26), scope="local") - C_local = T.decl_buffer((26, 26), scope="local") - for ax0, ax1 in T.grid(26, 128): - if i_0 * 26 + ax0 < 128: - A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1] - for ax0, ax1 in T.grid(128, 26): - if j_0 * 26 + ax1 < 128: - B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1] - for i_1, j_1, k in T.grid(26, 26, 128): - if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128: - if k == 0: - C_local[i_1, j_1] = T.float32(0) - C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] - for ax0, ax1 in T.grid(26, 26): - if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128: - C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] - - # Get partitioned workload to compact - before_mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): - before_mod = tvm.tir.transform.LowerOpaqueBlock()(before_mod) - before_mod = tvm.tir.transform.LoopPartition()(before_mod) - before = before_mod["main"] + @property + def before(self): + @T.prim_func + def main( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): + for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): + A_local = T.decl_buffer((26, 128), scope="local") + B_local = T.decl_buffer((128, 26), scope="local") + C_local = T.decl_buffer((26, 26), scope="local") + for ax0, ax1 in T.grid(26, 128): + if i_0 * 26 + ax0 < 128: + A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1] + for ax0, ax1 in T.grid(128, 26): + if j_0 * 26 + ax1 < 128: + B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1] + for i_1, j_1, k in T.grid(26, 26, 128): + if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128: + if k == 0: + C_local[i_1, j_1] = T.float32(0) + C_local[i_1, j_1] = ( + C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] + ) + for ax0, ax1 in T.grid(26, 26): + if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128: + C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] + + # Get partitioned workload to compact + mod = tvm.IRModule.from_expr(main) + with tvm.transform.PassContext( + config={"tir.LoopPartition": {"partition_const_loop": True}} + ): + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.tir.transform.LoopPartition()(mod) + + return mod["main"] @T.prim_func def expected( diff --git a/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py b/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py index 9f61b5a3920a..3078572bb508 100644 --- a/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest + import tvm import tvm.testing -from tvm import te +from tvm import te, tir + +import pytest import numpy as np @@ -184,7 +186,7 @@ def collect_branch_stmt(x): if isinstance(x, tvm.tir.IfThenElse): branch_collector.append(x) - n = 21 + n = tir.const(21) A = te.placeholder((n,), name="A") B = te.placeholder((n,), name="B") diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index 23a51a0817df..0b43db56f300 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -394,5 +394,144 @@ def func_without_arg( tvm.ir.assert_structural_equal(Expected, After) +def test_int_parameter(): + """Boolean may be passed to functions accepting int + + A PackedFunc produced by compiling an IRModule should support the + same type conversions as the C++ implementation. When a function + accepts an integer argument, the caller may call it with a boolean + value. + + This also provides backwards compatibility for functions that were + defined as accepting an integer, but are called with a boolean + argument. Prior to PackedFunc interface supporting boolean + arguments directly, the argument would be converted from boolean + to integer to be stored in a TVMValue. After adding support for + boolean arguments, this usage should not cause an error. + + """ + + @I.ir_module + class Before: + @T.prim_func + def main(arg: T.int32) -> T.int32: + T.func_attr({"target": T.target("llvm", host="llvm")}) + if arg > 0: + return 10 + else: + return 20 + + @I.ir_module + class Expected: + @T.prim_func + def main( + args: T.handle, + arg_type_ids: T.handle("int32"), + num_args: T.int32, + out_ret_value: T.handle("void"), + out_ret_tcode: T.handle("int32"), + resource_handle: T.handle, + ) -> T.int32: + T.func_attr( + { + "calling_conv": 1, + "target": T.target("llvm"), + } + ) + assert num_args == 1, "main: num_args should be 1" + assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" + assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" + arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) + arg_code: T.int32 = arg_type_ids_1[0] + assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be int" + arg: T.int32 = T.if_then_else( + arg_code == 0, + T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")), + T.Cast("int32", T.tvm_struct_get(args, 0, 12, "bool")), + ) + with T.attr(0, "compute_scope", "main_compute_"): + out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) + out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) + if arg > 0: + out_ret_value_1[0] = T.Cast("int64", 10) + out_ret_tcode_1[0] = 0 + return 0 + else: + out_ret_value_1[0] = T.Cast("int64", 20) + out_ret_tcode_1[0] = 0 + return 0 + return 0 + + After = tvm.tir.transform.MakePackedAPI()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + +def test_bool_parameter(): + """An integer may be passed to a function acccepting Boolean + + A PackedFunc produced by compiling an IRModule should support the + same type conversions as the C++ implementation. When a function + accepts a boolean argument, the caller may call it with an integer + value. + + """ + + @I.ir_module + class Before: + @T.prim_func + def main(arg: T.bool) -> T.int32: + T.func_attr({"target": T.target("llvm", host="llvm")}) + if arg: + return 10 + else: + return 20 + + @I.ir_module + class Expected: + @T.prim_func + def main( + args: T.handle, + arg_type_ids: T.handle("int32"), + num_args: T.int32, + out_ret_value: T.handle("void"), + out_ret_tcode: T.handle("int32"), + resource_handle: T.handle, + ) -> T.int32: + T.func_attr( + { + "calling_conv": 1, + "target": T.target("llvm"), + } + ) + assert num_args == 1, "main: num_args should be 1" + assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" + assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" + arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) + arg_code: T.int32 = arg_type_ids_1[0] + assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be boolean" + arg: T.bool = T.if_then_else( + arg_code == 15, + T.tvm_struct_get(args, 0, 12, "bool"), + T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")), + ) + with T.attr(0, "compute_scope", "main_compute_"): + out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) + out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) + if arg: + out_ret_value_1[0] = T.Cast("int64", 10) + out_ret_tcode_1[0] = 0 + return 0 + else: + out_ret_value_1[0] = T.Cast("int64", 20) + out_ret_tcode_1[0] = 0 + return 0 + return 0 + + After = tvm.tir.transform.MakePackedAPI()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index 4b71eb825414..68149e7d64bb 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -937,8 +937,8 @@ def test_vulkan_smem_reuse(): "kind": "vulkan", "max_num_threads": 256, "max_threads_per_block": 256, - "supports_float32": T.bool(True), - "supports_int32": T.bool(True), + "supports_float32": True, + "supports_int32": True, "tag": "", "thread_warp_size": 1, } diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py index 279785fdca51..d8212d38854c 100644 --- a/tests/python/tvmscript/test_tvmscript_error_report.py +++ b/tests/python/tvmscript/test_tvmscript_error_report.py @@ -332,26 +332,35 @@ def convert_slice_to_bufferload() -> None: check_error(convert_slice_to_bufferload, 6) -def test_tvm_exception_catch(): +def test_tvm_exception_catch_from_special_stmt(): def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error T.evaluate(1.0) + check_error(special_stmt_except, 2) + + +def test_tvm_exception_catch_from_scope_handler(): def scope_handler_except() -> None: for i in T.serial("1", "1"): # error T.evaluate(1) + check_error(scope_handler_except, 2) + + +def test_tvm_exception_catch_from_bare_intrin(): def intrin_except_unassign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") T.evaluate(A) # error + check_error(intrin_except_unassign, 3) + + +def test_tvm_exception_catch_from_assigned_intrin(): def intrin_except_assign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") A[0, 0] = A[A] # error - check_error(special_stmt_except, 2) - check_error(scope_handler_except, 2) - check_error(intrin_except_unassign, 3) check_error(intrin_except_assign, 3) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 8364e65a4178..b7ba57fa9387 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -230,7 +230,7 @@ def test_buffer_store(): obj, """ A = T.Buffer((128, 128), "float16") -A[128, 128] = A[128, 128] + T.float16(1) +A[128, 128] = A[128, 128] + T.float16(1.0) """, ) @@ -259,7 +259,7 @@ def test_let_stmt(): _assert_print( obj, """ -with T.LetStmt(T.float32(10)) as v: +with T.LetStmt(T.float32(10.0)) as v: T.evaluate(0) """, ) @@ -672,7 +672,7 @@ def test_call(): _assert_print( obj, """ -T.atan(T.float32(1)) +T.atan(T.float32(1.0)) """, ) @@ -682,7 +682,7 @@ def test_comm_reducer(): _assert_print( obj, """ -T.comm_reducer(lambda x, y: x + y, [T.float32(0)]) +T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]) """, ) @@ -712,7 +712,7 @@ def test_float_imm(): _assert_print( obj, """ -T.float16(1) +T.float16(1.0) """, ) @@ -942,7 +942,7 @@ def func(): @T.prim_func def func(): - T.evaluate(T.{dtype}(0)) + T.evaluate(T.{dtype}(0.0)) """ func = get_func(dtype) _assert_print(func, expected_output) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index f81a80de6d61..b44ff5ad7241 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2689,14 +2689,14 @@ def test_match_buffer_region(): outer_block = root.body.body.body.block assert len(outer_block.match_buffers) == 1 buffer_C = outer_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) + tvm.ir.assert_structural_equal(buffer_C.shape, [T.int32(16), T.int32(1), T.int32(4)]) assert isinstance(outer_block.body, tir.stmt.For) assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) inner_block = outer_block.body.body.block assert len(inner_block.match_buffers) == 1 buffer_D = inner_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) + tvm.ir.assert_structural_equal(buffer_D.shape, [T.int32(4), T.int32(1), T.int32(4)]) def block_elements(): @@ -3981,6 +3981,32 @@ def func() -> T.int32: return func +def func_attr_with_list(): + @T.prim_func + def func( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + D: T.Buffer((128, 128), "float32"), + ) -> None: + T.func_attr( + {"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [T.int32(1)]} + ) + C = T.alloc_buffer([128, 128], dtype="float32") + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("C"): + x, y, k = T.axis.remap("SSR", [i0, i1, i2]) + with T.init(): + C[x, y] = T.float32(0) + C[x, y] = C[x, y] + A[x, k] * B[y, k] + for i0, i1 in T.grid(128, 128): + with T.block("D"): + T.block_attr({"layout_free_placeholders": [C]}) + x, y = T.axis.remap("SS", [i0, i1]) + D[x, y] = C[x, y] + T.float32(1) + + return func + + def op_of_literal(): op_list = [ (T.exp, 0), @@ -4198,6 +4224,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return_zero, return_zero_private, return_zero_private_with_attr, + func_attr_with_list, *op_of_literal(), *relax_match_cast_struct_info_proxy(), relax_symbolic_size_var, diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 9bc9800c1cb8..ae83a9d66392 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -19,6 +19,7 @@ import tvm from tvm import te from tvm.topi import utils +from tvm.script import tir as T from .environment import get_env @@ -1046,19 +1047,19 @@ def _flatten_loop(src_coeff, dst_coeff, extents): assert len(dst_coeff) > 1 assert len(extents) != 0 tvm.ir.assert_structural_equal( - analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0 + analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), T.int32(0) ) tvm.ir.assert_structural_equal( - analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0 + analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), T.int32(0) ) - tvm.ir.assert_structural_equal(src_coeff[-2], 1) - tvm.ir.assert_structural_equal(dst_coeff[-2], 1) + tvm.ir.assert_structural_equal(src_coeff[-2], T.int32(1)) + tvm.ir.assert_structural_equal(dst_coeff[-2], T.int32(1)) if env.BATCH > 1: assert len(src_coeff) > 2 assert len(dst_coeff) > 2 assert len(extents) > 1 - tvm.ir.assert_structural_equal(src_coeff[-3], env.BLOCK_OUT) - tvm.ir.assert_structural_equal(dst_coeff[-3], env.BLOCK_OUT) + tvm.ir.assert_structural_equal(src_coeff[-3], T.int32(env.BLOCK_OUT)) + tvm.ir.assert_structural_equal(dst_coeff[-3], T.int32(env.BLOCK_OUT)) # Apply tensorization of the loop coefficients src_offset = src_coeff[-1] From fb16d9487d062353b1fed3b14729e9282da2b875 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 14 Aug 2024 18:25:09 +0530 Subject: [PATCH 061/202] [CODEGEN][OPENCL] Fix opencl codegen for few ops (#17273) * Compiler pass config to choose target clml support version Partition pass should shoose off loading ops based on target support this config enables choosing target version on python api aswell as tvmc. * Update clml.py * Fix opencl codegen for few ops Fixed the opencl codegen for few operators - 1. Atomic add for float - opencl doesn't have support float atomic add, Enabled work-around for this operation with atomic_cmpexch() 2. fmodf - Opencl only support fmod for all floating point 3. nearbyint - Opencl doesn't have this function and henced replaced with roud function. * Update test_relay_ops.py * Update codegen_opencl.cc * Update codegen_opencl.cc * Revert "Compiler pass config to choose target clml support version" This reverts commit bc955b02c436cdab7e397a2f1e66d828861da6e8. * Revert "Update clml.py" This reverts commit 4ff98a82dc463628f673292631df518e6831fd4e. --------- Co-authored-by: Siva Co-authored-by: B, Siva Rama Krishna Reddy Co-authored-by: Vegiraju, Krishna Raju --- python/tvm/topi/cuda/nms.py | 4 +- src/target/source/codegen_opencl.cc | 52 ++++++++++++- src/target/source/codegen_opencl.h | 1 + .../relay/opencl_texture/test_relay_ops.py | 73 +++++++++++++++++++ 4 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 tests/python/relay/opencl_texture/test_relay_ops.py diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index e402c5888978..f258bffc3e8f 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -50,7 +50,9 @@ def cuda_atomic_add_rule(op): def opencl_atomic_add_rule(op): if op.dtype == "int32": return tvm.tir.call_pure_extern("int32", "atomic_add", op.args[0], op.args[1]) - raise RuntimeError("only support int32") + elif op.dtype == "float32": + return tvm.tir.call_pure_extern("float32", "atomic_add", op.args[0], op.args[1]) + raise RuntimeError("only support int32, float32") register_intrin_lowering("tir.atomic_add", target="cuda", f=cuda_atomic_add_rule, level=99) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index f17a452d5c28..5933c9582cec 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -129,6 +129,16 @@ std::string CodeGenOpenCL::Finish() { if (enable_atomics_) { decl_stream << "#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n" "#pragma OPENCL EXTENSION cl_khr_global_int32_extended_atomics : enable\n\n"; + decl_stream << "__inline float atomic_add_float_emu(volatile __global float* sum, const float " + "toAdd) {\n" + "float next_value = 0;" + "float prev_value = 0;" + "do {\n" + "prev_value =*(sum);\n" + "next_value =prev_value + toAdd;\n" + "} while(atomic_cmpxchg((volatile global int *)(sum), *((int*)&prev_value), " + "*((int*)&next_value)) != *((int*)&prev_value));\n" + "return next_value;\n}\n"; } // Enable OpenCL 1.2 sampler-less texture reads, but utilize @@ -458,13 +468,21 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args.back(), os); os << "]"; } - } else if (op->op.same_as(builtin_call_extern_)) { + } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { auto func = Downcast(op->args[0]); // Enable atomics extension if used. - if (func->value == "atomic_add") { + if (func->value == "atomic_add" && op->dtype.is_float()) { enable_atomics_ = true; + this->PrintCallExtern(GetType(GetRef(op)), "atomic_add_float_emu", op->args, true, + os); + } else if (func->value == "nearbyint") { + this->PrintCallExtern(GetType(GetRef(op)), "round", op->args, true, os); + } else { + if (func->value == "atomic_add") { + enable_atomics_ = true; + } + CodeGenC::VisitExpr_(op, os); } - CodeGenC::VisitExpr_(op, os); } else { CodeGenC::VisitExpr_(op, os); } @@ -534,6 +552,34 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { PrintBinaryExpr(op, "max", os, this); } +void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) + std::string opstr; + if (op->dtype.is_int() || op->dtype.is_uint()) { + opstr = "%"; + } else { + ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got " + << op->dtype; + opstr = "fmod"; + } + if (op->dtype.lanes() == 1) { + if (isalpha(opstr.c_str()[0])) { + os << opstr.c_str() << '('; + this->PrintExpr(op->a, os); + os << ", "; + this->PrintExpr(op->b, os); + os << ')'; + } else { + os << '('; + this->PrintExpr(op->a, os); + os << ' ' << opstr.c_str() << ' '; + this->PrintExpr(op->b, os); + os << ')'; + } + } else { + this->PrintVecBinaryOp(opstr.c_str(), op->dtype, op->a, op->b, os); + } +} + void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) { std::ostringstream oss; os << "("; diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 8b365f85d6e6..e668f75b2ec2 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -74,6 +74,7 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const AndNode* op, std::ostream& os) final; void VisitExpr_(const OrNode* op, std::ostream& os) final; void VisitExpr_(const SelectNode* op, std::ostream& os) final; + void VisitExpr_(const ModNode* op, std::ostream& os) final; private: // whether enable fp16 and fp64 extension diff --git a/tests/python/relay/opencl_texture/test_relay_ops.py b/tests/python/relay/opencl_texture/test_relay_ops.py new file mode 100644 index 000000000000..686a9a9b9e89 --- /dev/null +++ b/tests/python/relay/opencl_texture/test_relay_ops.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +import tvm +import numpy as np +from tvm import relay +from tvm.relay import testing +from tvm.contrib import utils +from utils.adreno_utils import gpu_preprocess, build_run_compare, build_run_compare_vm + + +executor_type = tvm.testing.parameter("ge", "vm") +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_mod(remote, target, executor_type, dtype): + # NCHW + input_shape = (1, 25, 38, 64) + A = relay.var("data", shape=input_shape, dtype=dtype) + scale = relay.const(2.0, dtype=dtype) + op = relay.mod(A, scale) + mod = relay.Function([A], op) + + if executor_type == "ge": + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + else: + build_run_compare_vm(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_scatter_nd_add(remote, target, executor_type, dtype): + # NCHW + + A = relay.var("data", shape=(6, 30, 30, 256), dtype=dtype) + indices = relay.const(tvm.nd.array(np.random.randint(0, 1, (2, 6, 30, 30))), dtype="int64") + update = relay.const( + tvm.nd.array(np.random.uniform(-1, 1, size=(50, 50, 256)).astype(dtype)), dtype=dtype + ) + op = relay.scatter_nd(update, indices, A, mode="add") + mod = relay.Function([A], op) + shape_dict = { + "data": (6, 30, 30, 256), + } + dtype_dict = { + "data": dtype, + } + + if executor_type == "ge": + build_run_compare(remote, mod, {}, shape_dict, dtype_dict, target) + else: + build_run_compare_vm(remote, mod, {}, shape_dict, dtype_dict, target) + + +if __name__ == "__main__": + tvm.testing.main() From 132daf6c959efe04cffa90234ef1688d82d193e3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 15 Aug 2024 09:52:37 -0700 Subject: [PATCH 062/202] [Disco] Fix double free of nccl communicator (#17275) --- src/runtime/disco/nccl/nccl_context.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index 730479b61ac0..b874da219fe4 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -129,6 +129,9 @@ struct CCLThreadLocalContext { void Clear() { if (group_comm) { NCCL_CALL(ncclCommDestroy(group_comm)); + if (global_comm == group_comm) { + global_comm = nullptr; + } group_comm = nullptr; } if (global_comm) { From 4a37f64167ce80552719cf9975c5ff8e4a053538 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 17 Aug 2024 10:22:28 -0700 Subject: [PATCH 063/202] [KVCache] Increase coalesce threshold (#17280) This PR changes the threshold of coalesce in kvcache for better performance. --- src/runtime/relax_vm/paged_kv_cache.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index cf5de97202cc..6bf3dc7ce609 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -1727,7 +1727,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qkv_data->dtype); // Part 2. Split fused qkv and apply rotary embedding to q/k data. f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, - rope_mode_ == RoPEMode::kNormal); + static_cast(rope_mode_ == RoPEMode::kNormal)); // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set. if (append_before_attn_) { @@ -2202,7 +2202,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } double coalesce_ratio = 1.0 * page_counter_uncoalesced / page_counter_coalesced; // Do not coalesce and use batch decode kernel when coalesce ratio is small. - bool use_decode_kernel = is_decode_request_ && coalesce_ratio < 1.1; + bool use_decode_kernel = is_decode_request_ && coalesce_ratio < 32; return {use_decode_kernel || !enable_coalesce ? uncoalesced_block_ids : coalesced_block_ids, use_decode_kernel}; } From 517c420d7b89029638926f10bbe9bed27f23bb5f Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 19 Aug 2024 18:22:45 +0530 Subject: [PATCH 064/202] [TOPI][ADRENO] Add Group Conv2d texture schedule (#17274) * Added Support for Adreno Texture Based Group Convolution * Added Few Testcases and Fixed Compute * Limited Support for Group Convolution * Removed Dead Code, Fixed Minor Issues --------- Co-authored-by: Sanjay Shankar Krishnaa --- python/tvm/relay/op/strategy/adreno.py | 31 +- python/tvm/topi/adreno/__init__.py | 1 + python/tvm/topi/adreno/group_conv2d_nchw.py | 386 ++++++++++++++++++ .../test_group_conv2d_nchw_texture.py | 208 ++++++++++ 4 files changed, 625 insertions(+), 1 deletion(-) create mode 100644 python/tvm/topi/adreno/group_conv2d_nchw.py create mode 100644 tests/python/relay/opencl_texture/test_group_conv2d_nchw_texture.py diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index bacace9ad4f6..99e4d0a405f0 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -182,8 +182,37 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): + kernel_layout + ") - only support NCHW4c / OIHW4o and NHWC / HWOI layouts for conv2d" ) + elif (data_layout == "NCHW4c" or data_layout == "NCHW") and ( + kernel_layout == "OIHW" or kernel_layout == "OIHW4o" + ): + pad_in_chunks = (len(data.shape) == 5 and data.shape[1] % groups != 0) or ( + len(data.shape) == 4 and data.shape[1] % (groups * 4) != 0 + ) + pad_out_chunks = (len(kernel.shape) == 5 and kernel.shape[0] % groups != 0) or ( + len(kernel.shape) == 4 and kernel.shape[0] % (groups * 4) != 0 + ) + + if not (pad_in_chunks or pad_out_chunks): + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.group_conv2d_nchwc), + wrap_topi_schedule(topi.adreno.schedule_group_conv2d_nchwc), + name="group_conv2d_nchwc.image2d", + plevel=10, + ) + elif len(data.shape) == 4 and len(kernel.shape) == 4: + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), + wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), + name="group_conv2d_nchw.cuda", + ) + else: + raise RuntimeError( + "General group convolution is not currently supported for NCHWc layouts" + ) else: - raise RuntimeError("General group convolution is not currently supported") + raise RuntimeError( + "General group convolution has limited support for NCHW(4c) layouts..." + ) return strategy diff --git a/python/tvm/topi/adreno/__init__.py b/python/tvm/topi/adreno/__init__.py index cd42848b29b3..2c0ed20f1011 100644 --- a/python/tvm/topi/adreno/__init__.py +++ b/python/tvm/topi/adreno/__init__.py @@ -20,6 +20,7 @@ from .conv2d_nchw import * from .depthwise_conv2d_nchw import * from .conv2d_nhwc import * +from .group_conv2d_nchw import * from .depthwise_conv2d_nhwc import * from .pooling import * from .conv2d_alter_op import * diff --git a/python/tvm/topi/adreno/group_conv2d_nchw.py b/python/tvm/topi/adreno/group_conv2d_nchw.py new file mode 100644 index 000000000000..f1ab7fcf0e64 --- /dev/null +++ b/python/tvm/topi/adreno/group_conv2d_nchw.py @@ -0,0 +1,386 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return + +"""Group Conv2d NCHW Operator wt Schedule on Qualcomm Adreno GPU""" +import tvm +from tvm import te +from tvm import autotvm + +from ..utils import get_const_tuple, traverse_inline +from .utils import ( + split_to_chunks, + pack_input, + pack_filter, + expand_spatial_dimensions, + add_pad, + bind_data_copy, + get_default_conv2d_config, + get_texture_storage, +) + + +@autotvm.register_topi_schedule("group_conv2d_nchwc.image2d") +def schedule_group_conv2d_nchwc(cfg, outs): + """Create the schedule for group_conv2d_nchw""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "adreno_group_conv2d_latest_op": + schedule_group_conv2d_NCHWc_KCRSk(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("group_conv2d_nchwc.image2d") +def group_conv2d_nchwc(cfg, Input, Filter, stride, padding, dilation, out_dtype): + """ + Group Convolution Operator in NCHWc layout. + Algo: + 1. Convert into blocked format if we have 4d original tensor. + In case of AutoTVM we override the convert by just tensors since such conversion + will be absent for real blocked convolution, no sense to include into tuning + 2. Expand spatial dimensions to have width and height be dividable by factor 4 + This leads to slightly bigger amount of compute but allow utilize GPU much better + 3. Add paddings. This happens even if we do not need pad originaly. This is useful + due to work surrounding the gaps of texture annotation between Primary Functions + and limited support of textures in schedules. Later on this pad will be executed + separately and will produce texture + 4. 5d Convolution compute with accumulating into out_dtype + 5. Cast to the origin output data type + 6. For case of 4d convolution: convert of output from 5d to 4d + """ + + if out_dtype is None: + out_dtype = Input.dtype + + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + convert_from4d = False + if len(Input.shape) == 4: + batch, in_channels, in_height, in_width = Input.shape + in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(in_channels, 4) + + if autotvm.GLOBAL_SCOPE.in_tuning: + dshape = (batch, in_channel_chunks, in_height, in_width, in_channel_block) + Input = tvm.te.placeholder(dshape, Input.dtype, name="data_placeholder") + else: + Input = pack_input( + Input, + "NCHW", + batch, + in_channel_chunks, + in_channel_block, + in_channel_tail, + in_height, + in_width, + ) + else: + batch, in_channel_chunks, in_height, in_width, in_channel_block = Input.shape + in_channels = in_channel_chunks * in_channel_block + + if len(Filter.shape) == 4: + out_channels, in_filter_channels, kernel_h, kernel_w = Filter.shape + out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4) + + if autotvm.GLOBAL_SCOPE.in_tuning: + kshape = (out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block) + Filter = tvm.te.placeholder(kshape, Filter.dtype, name="kernel_placeholder") + else: + convert_from4d = True + Filter = pack_filter( + Filter, + "OIHW", + out_channel_chunks, + out_channel_block, + out_channel_tail, + in_filter_channels, + in_channel_chunks, + in_channel_block, + in_channel_tail, + kernel_h, + kernel_w, + ) + else: + out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block = Filter.shape + out_channels = out_channel_chunks * out_channel_block + + assert in_channels % in_filter_channels == 0 + groups = in_channels // in_filter_channels + + # Compute Constraints... + assert out_channel_chunks % groups == 0 + assert in_channel_chunks % groups == 0 + + out_height_orig, out_height, out_width_orig, out_width = expand_spatial_dimensions( + in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, padding, stride_h, stride_w + ) + + temp = add_pad( + Input, + "NCHW", + out_height_orig, + out_width_orig, + kernel_h, + kernel_w, + dilation_h, + dilation_w, + padding, + stride_h, + stride_w, + ) + + in_group_channel_chunks = in_channel_chunks // groups + in_group_channel_block = in_channel_block + out_group_channel_chunks = out_channel_chunks // groups + rcc = te.reduce_axis((0, in_group_channel_chunks), name="rcc") + rcb = te.reduce_axis((0, in_group_channel_block), name="rcb") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + conv = te.compute( + (batch, out_channel_chunks, out_height, out_width, out_channel_block), + lambda nn, occ, yy, xx, obb: te.sum( + ( + temp[ + nn, + occ // out_group_channel_chunks * in_group_channel_chunks + rcc, + yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, + rcb, + ] + * Filter[occ, rcc * in_group_channel_block + rcb, ry, rx, obb] + ).astype(out_dtype), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc_group", + ) + + if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning: + dummy_cast = te.compute( + (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block), + lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype), + tag="dummy_cast", + ) + return te.compute( + (batch, out_channels, out_height_orig, out_width_orig), + lambda n, c, y, x: dummy_cast[n, c // out_channel_block, y, x, c % out_channel_block], + tag="adreno_group_conv2d_latest_op", + ) + else: + return te.compute( + (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block), + lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype(out_dtype), + tag="adreno_group_conv2d_latest_op", + ) + + +def schedule_group_conv2d_NCHWc_KCRSk(cfg, s, output): + """ + Schedule optimized for batch size = 1 + + Algo: + 1. Split output axis to three parts: global work size, vthread, local worksize. + The limitations for tuning includes heuristics from some tuned networks to limit + search space and not pay much time for useles configurations. + 2. In case of 4d convolution schedule copying of the input (and filter) into + 5d tensors + 4. pad should be scheduled separately to create independent opencl kernel. If pad is + inlined into convolution, this gives 1.5x performance drop + 5. We are using cache_read for intermediate tensors to produce texture and guarantee + the best performance on the next stage. + The weights are managed through static texture planning mechanism and guarantied come + in texture memory scope. + Thus way we are calling cache_read only for data tensor + 6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize + for textures + For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion + of data type + 7. In case of 4d conv we need to schedule postops as well + """ + latest = s.outputs[0].output(0) + if len(latest.op.axis) == 4: + latest_blocked = dummy = output.op.input_tensors[0] + conv = dummy.op.input_tensors[0] + else: + conv = output.op.input_tensors[0] + latest_blocked = latest + + pad_data, kernel = s[conv].op.input_tensors + filter_pack_rt = bool( + isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag + ) + + if "pad_temp" in pad_data.op.name: + input_pad_temp = pad_data.op.input_tensors[0] + else: + input_pad_temp = pad_data + + input_pack_rt = bool( + isinstance(input_pad_temp.op, tvm.te.ComputeOp) and "input_pack" in input_pad_temp.op.tag + ) + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + rcc, rcb, ry, rx = s[conv].op.reduce_axis + + if conv.shape[1] % 2 == 0: + min_threads_div = 2 + else: + min_threads_div = 1 + cfg.define_split( + "tile_fc", + fc, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 + and entity.size[2] >= min_threads_div + and entity.size[2] < 256, + ) + cfg.define_split( + "tile_y", + y, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16, + ) + cfg.define_split( + "tile_x", + x, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16, + ) + + cfg.define_split("tile_rcc", rcc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + cfg.define_knob("unroll_explicit", [0, 1]) + cfg.multi_filter( + filter=lambda entity: ( # pylint: disable=chained-comparison + entity["tile_fc"].size[1] * entity["tile_y"].size[1] * entity["tile_x"].size[1] + ) + <= 24 + and 32 + <= (entity["tile_fc"].size[2] * entity["tile_y"].size[2] * entity["tile_x"].size[2]) + < 1024 + ) + if cfg.is_fallback: + get_default_conv2d_config(cfg, conv.shape[1], conv.shape[2], conv.shape[3]) + ##### space definition end ##### + + pad_data, kernel = s[conv].op.input_tensors + # There are several conditions that have to be handled: + # 1. If we are in the tuning, we always add cache read for data to main conv kernel + # to get texture in tuning opencl kernel + # 2. If we are repacking input in runtime, we should always explicit schedule this one more + # stage of data copy from 4d to 5d (referred as pack_data). + # 3. If we have pad (independently if we have runtime repack or not) we should inline it in the + # cache_read("texture") + if autotvm.GLOBAL_SCOPE.in_tuning or input_pack_rt: + if autotvm.GLOBAL_SCOPE.in_tuning: + if "pad_temp" in pad_data.op.name: + s[pad_data].compute_inline() + else: + if "pad_temp" in pad_data.op.name: + pack_data = pad_data.op.input_tensors[0] + bind_data_copy(s[pack_data]) + s[pad_data].compute_inline() + else: + pack_data = pad_data + bind_data_copy(s[pack_data]) + + AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv]) + bind_data_copy(s[AT]) + elif "pad_temp" in pad_data.op.name: + s[pad_data].compute_inline() + # create cache stage + AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv]) + bind_data_copy(s[AT]) + + if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt: + if not autotvm.GLOBAL_SCOPE.in_tuning: + bind_data_copy(s[kernel]) + if kernel.shape[2] == 1 and kernel.shape[3] == 1: + WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv]) + bind_data_copy(s[WT]) + + s[conv].set_scope("local") + if latest_blocked == latest and output != latest: + s[output].compute_inline() + + # tile and bind spatial axes + n, fc, y, x, fb = s[latest_blocked].op.axis + + kernel_scope, n = s[latest_blocked].split(n, nparts=1) + + bf, vf, tf = cfg["tile_fc"].apply(s, latest_blocked, fc) + by, vy, ty = cfg["tile_y"].apply(s, latest_blocked, y) + bx, vx, tx = cfg["tile_x"].apply(s, latest_blocked, x) + + bf = s[latest_blocked].fuse(n, bf) + s[latest_blocked].bind(bf, te.thread_axis("blockIdx.z")) + s[latest_blocked].bind(by, te.thread_axis("blockIdx.y")) + s[latest_blocked].bind(bx, te.thread_axis("blockIdx.x")) + s[latest_blocked].bind(vf, te.thread_axis("vthread")) + s[latest_blocked].bind(vy, te.thread_axis("vthread")) + s[latest_blocked].bind(vx, te.thread_axis("vthread")) + s[latest_blocked].bind(tf, te.thread_axis("threadIdx.z")) + s[latest_blocked].bind(ty, te.thread_axis("threadIdx.y")) + s[latest_blocked].bind(tx, te.thread_axis("threadIdx.x")) + s[latest_blocked].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fb) + s[latest_blocked].vectorize(fb) + + s[conv].compute_at(s[latest_blocked], tx) + + # tile reduction axes + n, fc, y, x, fb = s[conv].op.axis + rcc, rcb, ry, rx = s[conv].op.reduce_axis + + rco, rci = cfg["tile_rcc"].apply(s, conv, rcc) + ryo, ryi = cfg["tile_ry"].apply(s, conv, ry) + rxo, rxi = cfg["tile_rx"].apply(s, conv, rx) + s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) + s[conv].unroll(rcb) + s[conv].vectorize(fb) + + # unroll + s[latest_blocked].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[latest_blocked].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) + + if latest_blocked != latest: + s[latest].compute_root() + bind_data_copy(s[latest], 1) + if latest != output: + s[output].compute_inline() + + N, OCC, OH, OW, OCB = get_const_tuple(latest_blocked.shape) + _, IC, KH, KW, _ = get_const_tuple(kernel.shape) + ICKHKW = IC * KH * KW + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) diff --git a/tests/python/relay/opencl_texture/test_group_conv2d_nchw_texture.py b/tests/python/relay/opencl_texture/test_group_conv2d_nchw_texture.py new file mode 100644 index 000000000000..bd05610e92b7 --- /dev/null +++ b/tests/python/relay/opencl_texture/test_group_conv2d_nchw_texture.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import re +import tvm +import numpy as np +from tvm import relay +from tvm.relay import testing +from utils.adreno_utils import build_run_compare, build_run_compare_vm + +executor_type = tvm.testing.parameter("ge", "vm") +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_group_conv2d_nchwc_adreno_encoder1(remote, target, executor_type, dtype): + input_shape = (1, 512, 56, 100) + filter_shape = (512, 64, 3, 3) + bias_shape = (1, 512, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + B = relay.var("weight", shape=filter_shape, dtype=dtype) + bias = relay.var("bias", shape=bias_shape, dtype=dtype) + + conv = relay.nn.conv2d( + A, + B, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[1, 1, 1, 1], + strides=[1, 1], + out_dtype=dtype, + channels=512, + groups=8, + dilation=1, + kernel_size=(3, 3), + ) + D = relay.op.add(conv, bias) + D = relay.op.nn.relu(D) + + mod = relay.Function([A, B, bias], D) + np.random.seed(1) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + bias_data = np.zeros(bias_shape).astype(dtype) + initializer("weight", filter_data) + initializer("bias", bias_data) + params1 = { + "weight": tvm.nd.array(filter_data), + "bias": tvm.nd.array(bias_data), + } + + if executor_type == "ge": + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + else: + build_run_compare_vm(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_group_conv2d_nchwc_adreno_encoder2(remote, target, executor_type, dtype): + input_shape = (1, 1024, 56, 100) + filter_shape = (512, 128, 3, 3) + bias_shape = (1, 512, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + B = relay.var("weight", shape=filter_shape, dtype=dtype) + bias = relay.var("bias", shape=bias_shape, dtype=dtype) + + conv = relay.nn.conv2d( + A, + B, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[3, 3, 3, 3], + strides=[2, 2], + out_dtype=dtype, + channels=512, + groups=8, + dilation=2, + kernel_size=(3, 3), + ) + D = relay.op.add(conv, bias) + D = relay.op.nn.relu(D) + + mod = relay.Function([A, B, bias], D) + np.random.seed(1) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + bias_data = np.zeros(bias_shape).astype(dtype) + initializer("weight", filter_data) + initializer("bias", bias_data) + params1 = { + "weight": tvm.nd.array(filter_data), + "bias": tvm.nd.array(bias_data), + } + + if executor_type == "ge": + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + else: + build_run_compare_vm(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_group_conv2d_nchwc_adreno_nontrivial(remote, target, executor_type, dtype): + input_shape = (1, 56, 56, 100) + filter_shape = (112, 8, 7, 3) + bias_shape = (1, 112, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + B = relay.var("weight", shape=filter_shape, dtype=dtype) + bias = relay.var("bias", shape=bias_shape, dtype=dtype) + + conv = relay.nn.conv2d( + A, + B, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[3, 3, 3, 3], + strides=[1, 2], + out_dtype=dtype, + channels=112, + groups=7, + dilation=2, + kernel_size=(7, 3), + ) + D = relay.op.add(conv, bias) + D = relay.op.nn.relu(D) + + mod = relay.Function([A, B, bias], D) + np.random.seed(1) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + bias_data = np.zeros(bias_shape).astype(dtype) + initializer("weight", filter_data) + initializer("bias", bias_data) + params1 = { + "weight": tvm.nd.array(filter_data), + "bias": tvm.nd.array(bias_data), + } + + if executor_type == "ge": + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + else: + build_run_compare_vm(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_group_conv2d_nchwc_default(remote, target, executor_type, dtype): + input_shape = (1, 49, 56, 100) + filter_shape = (343, 7, 3, 3) + bias_shape = (1, 343, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + B = relay.var("weight", shape=filter_shape, dtype=dtype) + bias = relay.var("bias", shape=bias_shape, dtype=dtype) + + # C = relay.nn.relu(A) + conv = relay.nn.conv2d( + A, + B, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[1, 1, 1, 1], + strides=[1, 1], + out_dtype=dtype, + channels=343, + groups=7, + dilation=1, + kernel_size=(3, 3), + ) + D = relay.op.add(conv, bias) + D = relay.op.nn.relu(D) + + mod = relay.Function([A, B, bias], D) + np.random.seed(1) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + bias_data = np.zeros(bias_shape).astype(dtype) + initializer("weight", filter_data) + initializer("bias", bias_data) + params1 = { + "weight": tvm.nd.array(filter_data), + "bias": tvm.nd.array(bias_data), + } + + if executor_type == "ge": + build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + else: + build_run_compare_vm(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target) + + +if __name__ == "__main__": + tvm.testing.main() From 6bcec1d6c358268b12da733d995f61bb7384b0ac Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 19 Aug 2024 08:29:59 -0500 Subject: [PATCH 065/202] [CI] Resolve CI compilation failures on MacOSX (#17271) * Debug, list configs in base conda environment * Add the "auto-update-conda: true" flag for miniconda setup It looks like the base environment provides `conda==24.5.0`, but the `tvm-build` environment only provides `conda==23.9.0`, and the error in `cargo build` is triggered from within the `tvm-build` environment. Seeing if it just needs to be allowed to update to a newer `conda` version. * Attempt bumping the required conda version The `conda-build` package specifies compatibility with `conda >= 23.7`, but the `libmamba` requirement requirement isn't provided until `23.10`. Possibly an incompatibility, where the default solver is decided based on the base environment's `conda` version, but the availability is based on the `tvm-build` environment. * Try adding "conda-solver: classic" Since libmamba isn't available inside the generated environment * Exit on cmake failure in Windows build * Exit on first error for Windows conda build From what I can tell, batch scripts do not have an equivalent to `set -e`, so this needs to be added to every command in the batch scripts. --- .github/actions/setup/action.yml | 4 ++++ conda/build_win.bat | 4 +++- conda/recipe/bld.bat | 2 +- conda/recipe/install_libtvm.bat | 8 +++++--- conda/recipe/install_tvm_python.bat | 4 ++-- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index 40ddf4f90678..6fd81c1d6903 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -15,6 +15,7 @@ runs: channel-priority: strict environment-file: conda/build-environment.yaml auto-activate-base: false + conda-solver: classic use-only-tar-bz2: true python-version: 3.9 condarc-file: conda/condarc @@ -25,6 +26,7 @@ runs: channel-priority: strict environment-file: conda/build-environment.yaml auto-activate-base: false + conda-solver: classic use-only-tar-bz2: true python-version: 3.9 condarc-file: conda/condarc @@ -33,3 +35,5 @@ runs: run: | conda info conda list + conda info --envs + conda list --name base diff --git a/conda/build_win.bat b/conda/build_win.bat index 59d0d07340c7..e37a06ce7c05 100644 --- a/conda/build_win.bat +++ b/conda/build_win.bat @@ -15,4 +15,6 @@ :: specific language governing permissions and limitations :: under the License. -conda build --output-folder=conda/pkg conda/recipe +echo on + +conda build --output-folder=conda/pkg conda/recipe || exit /b diff --git a/conda/recipe/bld.bat b/conda/recipe/bld.bat index f8988b135793..561dcff87802 100644 --- a/conda/recipe/bld.bat +++ b/conda/recipe/bld.bat @@ -32,7 +32,7 @@ cmake ^ -DUSE_RANDOM=ON ^ -DUSE_PROFILER=ON ^ -DINSTALL_DEV=ON ^ - %SRC_DIR% + %SRC_DIR% || exit /b cd .. :: defer build to install stage to avoid rebuild. diff --git a/conda/recipe/install_libtvm.bat b/conda/recipe/install_libtvm.bat index f423c521f84e..c56f83bfaaef 100644 --- a/conda/recipe/install_libtvm.bat +++ b/conda/recipe/install_libtvm.bat @@ -15,8 +15,10 @@ :: specific language governing permissions and limitations :: under the License. -cmake --build build --config Release --target install +echo on + +cmake --build build --config Release --target install || exit /b :: Copy files into library bin so that they can be found -cp %LIBRARY_LIB%\tvm.dll %LIBRARY_BIN%\tvm.dll -cp %LIBRARY_LIB%\tvm_runtime.dll %LIBRARY_BIN%\tvm_runtime.dll +cp %LIBRARY_LIB%\tvm.dll %LIBRARY_BIN%\tvm.dll || exit /b +cp %LIBRARY_LIB%\tvm_runtime.dll %LIBRARY_BIN%\tvm_runtime.dll || exit /b diff --git a/conda/recipe/install_tvm_python.bat b/conda/recipe/install_tvm_python.bat index 96187468c2b2..07c0465b8443 100644 --- a/conda/recipe/install_tvm_python.bat +++ b/conda/recipe/install_tvm_python.bat @@ -16,5 +16,5 @@ :: under the License. echo on -cd %SRC_DIR%\python -%PYTHON% setup.py install --single-version-externally-managed --record=%SRC_DIR%\record.txt +cd %SRC_DIR%\python || exit /b +%PYTHON% setup.py install --single-version-externally-managed --record=%SRC_DIR%\record.txt || exit /b From 6f4ac2312b9bbcbfb465ead0de410ab7dd1494a4 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 19 Aug 2024 22:31:50 +0900 Subject: [PATCH 066/202] [Relay][Pytorch] Add support for `aten::tile` (#17277) * add test for torch.tile * add support for `aten::tile` --- python/tvm/relay/frontend/pytorch.py | 11 +++++++++ tests/python/frontend/pytorch/test_forward.py | 24 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 1f78d7739007..0d93ff987c6e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4022,6 +4022,16 @@ def scaled_dot_product_attention(self, inputs, input_types): attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2]) return attn_weight + def tile(self, inputs, input_types): + data = inputs[0] + reps = [] + for r in inputs[1]: + if isinstance(r, int): + reps.append(r) + else: + reps.append(int(_infer_value(r, {}).numpy())) + return _op.tile(data, reps) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -4302,6 +4312,7 @@ def create_convert_map(self): "aten::swapaxes": self.transpose, "aten::linalg_vector_norm": self.linalg_vector_norm, "aten::scaled_dot_product_attention": self.scaled_dot_product_attention, + "aten::tile": self.tile, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index a273af8fb89d..9f8fac93061c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -5658,6 +5658,30 @@ def forward(self, x): verify_model(ParamListModel().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_tile(): + """test_forward_repeat""" + torch.set_grad_enabled(False) + input_shape = [1, 3] + + class Tile1(Module): + def forward(self, *args): + return args[0].tile(1, 1) + + class Tile2(Module): + def forward(self, *args): + return args[0].tile(4, 2) + + class Tile3(Module): + def forward(self, *args): + return args[0].tile(4, 2, 1) + + input_data = torch.rand(input_shape).float() + verify_model(Tile1().float().eval(), input_data=input_data) + verify_model(Tile2().float().eval(), input_data=input_data) + verify_model(Tile3().float().eval(), input_data=input_data) + + class TestSetSpan: """test structural equal between translated / hand-crafted relay IR with span tagged.""" From 1ca9833db2289923c4a557385be05307afb2e9ca Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 19 Aug 2024 08:33:54 -0500 Subject: [PATCH 067/202] [IR] Handle NaN in StructuralEqual and StructuralHash (#17249) * [IR] Handle NaN in StructuralEqual and StructuralHash Prior to this commit, `NaN` values did not have any special handling in either `StructuralEqual` or `StructuralHash`. `StructuralEqual` checked whether the LHS and RHS were within some tolerance of each other. If the LHS and RHS are both `NaN`, this would evaluate to false. The updated `StructuralEqual` now checks for this case, and returns true if both sides are `NaN`. `StructuralHash` used the bit-pattern of a floating-point number to compute the hash. A `NaN` value may have any non-zero value in its mantissa, and so this could produce distinct hashes for ASTs that differ only by the choice of non-zero value. The updated `StructuralHash` uses the same `std::numeric_limits #include +#include #include namespace tvm { @@ -38,11 +39,21 @@ namespace tvm { class BaseValueEqual { public: bool operator()(const double& lhs, const double& rhs) const { - // fuzzy float pt comparison - constexpr double atol = 1e-9; - if (lhs == rhs) return true; - double diff = lhs - rhs; - return diff > -atol && diff < atol; + if (std::isnan(lhs) && std::isnan(rhs)) { + // IEEE floats do not compare as equivalent to each other. + // However, for the purpose of comparing IR representation, two + // NaN values are equivalent. + return true; + } else if (std::isnan(lhs) || std::isnan(rhs)) { + return false; + } else if (lhs == rhs) { + return true; + } else { + // fuzzy float pt comparison + constexpr double atol = 1e-9; + double diff = lhs - rhs; + return diff > -atol && diff < atol; + } } bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 774021ad1564..553f284b8c5a 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -27,7 +27,9 @@ #include #include +#include #include +#include #include namespace tvm { @@ -52,7 +54,16 @@ class BaseValueHash { public: uint64_t operator()(const float& key) const { return Reinterpret(key); } - uint64_t operator()(const double& key) const { return Reinterpret(key); } + uint64_t operator()(const double& key) const { + if (std::isnan(key)) { + // The IEEE format defines more than one bit-pattern that + // represents NaN. For the purpose of comparing IR + // representations, all NaN values are considered equivalent. + return Reinterpret(std::numeric_limits::quiet_NaN()); + } else { + return Reinterpret(key); + } + } uint64_t operator()(const int64_t& key) const { return Reinterpret(key); } uint64_t operator()(const uint64_t& key) const { return key; } uint64_t operator()(const int& key) const { return Reinterpret(key); } diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tir-base/test_tir_structural_equal_hash.py index eca78d649b85..32099cecf4b2 100644 --- a/tests/python/tir-base/test_tir_structural_equal_hash.py +++ b/tests/python/tir-base/test_tir_structural_equal_hash.py @@ -419,5 +419,48 @@ def func(A: T.Buffer(1, "int32")): assert '.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0] +def test_nan_values_are_equivalent(): + """Structural equality treats two NaN values as equivalent. + + By IEEE, a check of `NaN == NaN` returns false, as does + `abs(NaN - NaN) < tolerance`. However, for the purpose of + comparing IR representations, both NaN values are equivalent. + + """ + + @T.prim_func(private=True) + def func_1(): + return T.float32("nan") + + @T.prim_func(private=True) + def func_2(): + return T.float32("nan") + + tvm.ir.assert_structural_equal(func_1, func_2) + assert tvm.ir.structural_hash(func_1) == tvm.ir.structural_hash(func_2) + + +def test_all_nan_values_are_equivalent(): + """Structural equality treats two NaN values as equivalent. + + IEEE defines NaN as any value that has all exponent bits set, + and has a non-zero mantissa. For the purposes of comparing IR + representations, all NaN values are considered equivalent. + + """ + + # A NaN with the first payload bit set. + nan_all_zeros = np.int32(0x7FC00000).view("float32") + + # A NaN with the last payload bit set. + nan_with_payload = np.int32(0x7F800001).view("float32") + + float_1 = T.float32(nan_all_zeros) + float_2 = T.float32(nan_with_payload) + + tvm.ir.assert_structural_equal(float_1, float_2) + assert tvm.ir.structural_hash(float_1) == tvm.ir.structural_hash(float_2) + + if __name__ == "__main__": tvm.testing.main() From 7bea15f162ceb3f38809212eec5d711929709620 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 21 Aug 2024 00:53:53 +0530 Subject: [PATCH 068/202] [WINDOWS] Compiler options for non x86 targets (#17260) --- python/tvm/contrib/cc.py | 5 ++++- python/tvm/dlight/gpu/gemv.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 59b57e08ba49..110f80db6186 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -372,8 +372,11 @@ def _linux_compile( def _windows_compile(output, objects, options, cwd=None, ccache_env=None): - cmd = ["clang"] + compiler = os.getenv("TVM_WIN_CC", default="clang") + win_target = os.getenv("TVM_WIN_TARGET", default="x86_64") + cmd = [compiler] cmd += ["-O2"] + cmd += ["--target=" + win_target] if output.endswith(".so") or output.endswith(".dll"): cmd += ["-shared"] diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 2bcb8563a294..cff234140e50 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -11,7 +11,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the +# KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """A rule for GEMV and DecodeGEMV.""" @@ -478,7 +478,9 @@ def apply( TS, TR = 8, 64 else: TS, TR = 1, 64 - elif target.kind.name == "opencl" and "android" in str(target.host): + elif target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" VEC_C = 8 LOAD_V_SHARED = False @@ -686,7 +688,9 @@ def apply( DEC_PACK = 8 SCALE_PACK = 4 - if target.kind.name == "opencl" and "android" in str(target.host): + if target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" VEC_C = 8 UNROLL = 8 @@ -756,7 +760,10 @@ def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid ): """Schedule the outer reduction block.""" # NOTE: Only Android is supported so far - if not (target.kind.name == "opencl" and "android" in str(target.host)): + if not ( + target.kind.name == "opencl" + and (("android" in str(target.host)) or ("adreno" in str(target.attrs))) + ): return None batch, s, r, c = sch.get_loops(block) len_s = get_extent(sch, s) From dc247816f0b6be770a39064286d9723df6782a86 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 21 Aug 2024 20:52:51 +0800 Subject: [PATCH 069/202] [Doc] Refactor install docs (#17287) * [Doc] Refactor install docs The major updates include: 1. remove nnpack installation guide 2. refactor building guide into step-by-step instructions * update for ci --- docs/install/from_source.rst | 421 ++++++++++++++--------------------- docs/install/index.rst | 3 +- docs/install/nnpack.rst | 118 ---------- 3 files changed, 163 insertions(+), 379 deletions(-) delete mode 100644 docs/install/nnpack.rst diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 4dc14863a83b..a963d06ab559 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -19,240 +19,239 @@ Install from Source =================== -This page gives instructions on how to build and install the TVM package from -scratch on various systems. It consists of two steps: +This page gives instructions on how to build and install the TVM package from source. -1. First build the shared library from the C++ codes (`libtvm.so` for linux, `libtvm.dylib` for macOS and `libtvm.dll` for windows). -2. Setup for the language packages (e.g. Python Package). +.. contents:: Table of Contents + :local: + :depth: 2 -To get started, download tvm source code from the `Download Page `_. +.. _install-dependencies: -Developers: Get Source from Github ----------------------------------- -You can also choose to clone the source repo from github. -It is important to clone the submodules along, with ``--recursive`` option. +Step 1. Install Dependencies +---------------------------- -.. code:: bash +Apache TVM requires the following dependencies: - git clone --recursive https://github.com/apache/tvm tvm +- CMake (>= 3.24.0) +- LLVM (recommended >= 15) +- Git +- A recent C++ compiler supporting C++ 17, at the minimum + - GCC 7.1 + - Clang 5.0 + - Apple Clang 9.3 + - Visual Studio 2019 (v16.7) +- Python (>= 3.8) +- (Optional) Conda (Strongly Recommended) -For windows users who use github tools, you can open the git shell, and type the following command. +To easiest way to manage dependency is via conda, which maintains a set of toolchains +including LLVM across platforms. To create the environment of those build dependencies, +one may simply use: .. code:: bash - git submodule init - git submodule update + # make sure to start with a fresh environment + conda env remove -n tvm-build-venv + # create the conda environment with build dependency + conda create -n tvm-build-venv -c conda-forge \ + "llvmdev>=15" \ + "cmake>=3.24" \ + git \ + python=3.11 + # enter the build environment + conda activate tvm-build-venv -.. _build-shared-library: +Step 2. Get Source from Github +------------------------------ +You can also choose to clone the source repo from github. -Build the Shared Library ------------------------- +.. code:: bash -Our goal is to build the shared libraries: + git clone --recursive https://github.com/apache/tvm tvm - - On Linux the target library are `libtvm.so` and `libtvm_runtime.so` - - On macOS the target library are `libtvm.dylib` and `libtvm_runtime.dylib` - - On Windows the target library are `libtvm.dll` and `libtvm_runtime.dll` +.. note:: + It's important to use the ``--recursive`` flag when cloning the TVM repository, which will + automatically clone the submodules. If you forget to use this flag, you can manually clone the submodules + by running ``git submodule update --init --recursive`` in the root directory of the TVM repository. -It is also possible to :ref:`build the runtime ` library only. +Step 3. Configure and Build +--------------------------- +Create a build directory and run CMake to configure the build. The following example shows how to build -The minimal building requirements for the ``TVM`` libraries are: +.. code:: bash - - A recent C++ compiler supporting C++ 17, at the minimum - - GCC 7.1 - - Clang 5.0 - - Apple Clang 9.3 - - Visual Studio 2019 (v16.7) - - CMake 3.18 or higher - - We highly recommend to build with LLVM to enable all the features. - - If you want to use CUDA, CUDA toolkit version >= 8.0 is required. If you are upgrading from an older version, make sure you purge the older version and reboot after installation. - - On macOS, you may want to install `Homebrew `_ to easily install and manage dependencies. - - Python is also required. Avoid using Python 3.9.X+ which is not `supported `_. 3.7.X+ and 3.8.X+ should be well supported however. + cd tvm + rm -rf build && mkdir build && cd build + # Specify the build configuration via CMake options + cp ../cmake/config.cmake . -To install the these minimal pre-requisites on Ubuntu/Debian like -linux operating systems, execute (in a terminal): +We want to specifically tweak the following flags by appending them to the end of the configuration file: .. code:: bash - sudo apt-get update - sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev - - -Note that the version of CMake on apt may not be sufficiently up to date; it may be necessary to install it directly from `Kitware's third-party APT repository `_. + # controls default compilation flags (Candidates: Release, Debug, RelWithDebInfo) + echo "set(CMAKE_BUILD_TYPE RelWithDebInfo)" >> config.cmake + # LLVM is a must dependency for compiler end + echo "set(USE_LLVM \"llvm-config --ignore-libllvm --link-static\")" >> config.cmake + echo "set(HIDE_PRIVATE_SYMBOLS ON)" >> config.cmake -On Fedora/CentOS and related operating systems use: + # GPU SDKs, turn on if needed + echo "set(USE_CUDA OFF)" >> config.cmake + echo "set(USE_METAL OFF)" >> config.cmake + echo "set(USE_VULKAN OFF)" >> config.cmake + echo "set(USE_OPENCL OFF)" >> config.cmake -.. code:: bash + # cuBLAS, cuDNN, cutlass support, turn on if needed + echo "set(USE_CUBLAS OFF)" >> config.cmake + echo "set(USE_CUDNN OFF)" >> config.cmake + echo "set(USE_CUTLASS OFF)" >> config.cmake - sudo dnf update - sudo dnf groupinstall -y "Development Tools" - sudo dnf install -y python-devel ncurses-compat-libs zlib-devel cmake libedit-devel libxml2-devel -Use Homebrew to install the required dependencies for macOS running either the Intel or M1 processors. You must follow the post-installation steps specified by -Homebrew to ensure the dependencies are correctly installed and configured: +.. note:: + ``HIDE_PRIVATE_SYMBOLS`` is a configuration option that enables the ``-fvisibility=hidden`` flag. + This flag helps prevent potential symbol conflicts between TVM and PyTorch. These conflicts arise due to + the frameworks shipping LLVMs of different versions. -.. code:: bash + `CMAKE_BUILD_TYPE `_ controls default compilation flag: - brew install gcc git cmake - brew install llvm - brew install python@3.8 + - ``Debug`` sets ``-O0 -g`` + - ``RelWithDebInfo`` sets ``-O2 -g -DNDEBUG`` (recommended) + - ``Release`` sets ``-O3 -DNDEBUG`` -If you are on macOS with an M1 Processor you may need to use conda to manage dependencies while building. Specifically you may need, `Miniforge `_ to ensure that the dependencies obtained using pip are compatible with M1. +Once ``config.cmake`` is edited accordingly, kick off build with the commands below: -.. code:: bash +.. code-block:: bash - brew install miniforge - conda init - conda create --name tvm python=3.8 - conda activate tvm + cmake .. && cmake --build . --parallel $(nproc) -We use cmake to build the library. -The configuration of TVM can be modified by editing `config.cmake` and/or by passing cmake flags to the command line: +.. note:: + ``nproc`` may not be available on all systems, please replace it with the number of cores on your system +A success build should produce ``libtvm`` and ``libtvm_runtime`` under ``build/`` directory. -- First, check the cmake in your system. If you do not have cmake, - you can obtain the latest version from `official website `_ -- First create a build directory, copy the ``cmake/config.cmake`` to the directory. +Leaving the build environment ``tvm-build-venv``, there are two ways to install the successful build into your environment: - .. code:: bash +- Install via environment variable - mkdir build - cp cmake/config.cmake build +.. code-block:: bash -- Edit ``build/config.cmake`` to customize the compilation options + export TVM_HOME=/path-to-tvm + export PYTHONPATH=$TVM_HOME/python:$PYTHONPATH - - On macOS, for some versions of Xcode, you need to add ``-lc++abi`` in the LDFLAGS or you'll get link errors. - - Change ``set(USE_CUDA OFF)`` to ``set(USE_CUDA ON)`` to enable CUDA backend. Do the same for other backends and libraries - you want to build for (OpenCL, RCOM, METAL, VULKAN, ...). - - To help with debugging, ensure the embedded graph executor and debugging functions are enabled with ``set(USE_GRAPH_EXECUTOR ON)`` and ``set(USE_PROFILER ON)`` - - To debug with IRs, ``set(USE_RELAY_DEBUG ON)`` and set environment variable `TVM_LOG_DEBUG`. +- Install via pip local project - .. code:: bash +.. code-block:: bash - export TVM_LOG_DEBUG="ir/transform.cc=1,relay/ir/transform.cc=1" + conda activate your-own-env + conda install python # make sure python is installed + cd /path-to-tvm/python + pip install -e . -- TVM requires LLVM for CPU codegen. We highly recommend you to build with the LLVM support on. +Step 4. Validate Installation +----------------------------- - - LLVM 4.0 or higher is needed for build with LLVM. Note that version of LLVM from default apt may lower than 4.0. - - Since LLVM takes long time to build from source, you can download pre-built version of LLVM from - `LLVM Download Page `_. +Using a compiler infrastructure with multiple language bindings could be error-prone. +Therefore, it is highly recommended to validate Apache TVM installation before use. - - Unzip to a certain location, modify ``build/config.cmake`` to add ``set(USE_LLVM /path/to/your/llvm/bin/llvm-config)`` - - You can also directly set ``set(USE_LLVM ON)`` and let cmake search for a usable version of LLVM. +**Step 1. Locate TVM Python package.** The following command can help confirm that TVM is properly installed as a python package and provide the location of the TVM python package: - - You can also use `LLVM Nightly Ubuntu Build `_ +.. code-block:: bash - - Note that apt-package append ``llvm-config`` with version number. - For example, set ``set(USE_LLVM llvm-config-10)`` if you installed LLVM 10 package + >>> python -c "import tvm; print(tvm.__file__)" + /some-path/lib/python3.11/site-packages/tvm/__init__.py - - If you are a PyTorch user, it is recommended to set ``(USE_LLVM "/path/to/llvm-config --link-static")`` and ``set(HIDE_PRIVATE_SYMBOLS ON)`` - to avoid potential symbol conflicts between different versions LLVM used by TVM and PyTorch. +**Step 2. Confirm which TVM library is used.** When maintaining multiple build or installation of TVM, it becomes important to double check if the python package is using the proper ``libtvm`` with the following command: - - On supported platforms, the `Ccache compiler wrapper `_ may be helpful for - reducing TVM's build time. There are several ways to enable CCache in TVM builds: +.. code-block:: bash - - Leave `USE_CCACHE=AUTO` in `build/config.cmake`. CCache will be used if it is found. + >>> python -c "import tvm; print(tvm._ffi.base._LIB)" + - - Ccache's Masquerade mode. This is typically enabled during the Ccache installation process. - To have TVM use Ccache in masquerade, simply specify the appropriate C/C++ compiler - paths when configuring TVM's build system. For example: - ``cmake -DCMAKE_CXX_COMPILER=/usr/lib/ccache/c++ ...``. +**Step 3. Reflect TVM build option.** Sometimes when downstream application fails, it could likely be some mistakes with a wrong TVM commit, or wrong build flags. To find it out, the following commands will be helpful: - - Ccache as CMake's C++ compiler prefix. When configuring TVM's build system, - set the CMake variable ``CMAKE_CXX_COMPILER_LAUNCHER`` to an appropriate value. - E.g. ``cmake -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ...``. +.. code-block:: bash -- We can then build tvm and related libraries. + >>> python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))" + ... # Omitted less relevant options + GIT_COMMIT_HASH: 4f6289590252a1cf45a4dc37bce55a25043b8338 + HIDE_PRIVATE_SYMBOLS: ON + USE_LLVM: llvm-config --link-static + LLVM_VERSION: 15.0.7 + USE_VULKAN: OFF + USE_CUDA: OFF + CUDA_VERSION: NOT-FOUND + USE_OPENCL: OFF + USE_METAL: ON + USE_ROCM: OFF - .. code:: bash - cd build - cmake .. - make -j4 +**Step 4. Check device detection.** Sometimes it could be helpful to understand if TVM could detect your device at all with the following commands: - - You can also use Ninja build system instead of Unix Makefiles. It can be faster to build than using Makefiles. +.. code-block:: bash - .. code:: bash + >>> python -c "import tvm; print(tvm.metal().exist)" + True # or False + >>> python -c "import tvm; print(tvm.cuda().exist)" + False # or True + >>> python -c "import tvm; print(tvm.vulkan().exist)" + False # or True - cd build - cmake .. -G Ninja - ninja +Please note that the commands above verify the presence of an actual device on the local machine for the TVM runtime (not the compiler) to execute properly. However, TVM compiler can perform compilation tasks without requiring a physical device. As long as the necessary toolchain, such as NVCC, is available, TVM supports cross-compilation even in the absence of an actual device. - - There is also a makefile in the top-level tvm directory that can - automate several of these steps. It will create the build - directory, copy the default ``config.cmake`` to the build - directory, run cmake, then run make. - The build directory can be specified using the environment - variable ``TVM_BUILD_PATH``. If ``TVM_BUILD_PATH`` is unset, the - makefile assumes that the ``build`` directory inside tvm should be - used. Paths specified by ``TVM_BUILD_PATH`` can be either - absolute paths or paths relative to the base tvm directory. - ``TVM_BUILD_PATH`` can also be set to a list of space-separated - paths, in which case all paths listed will be built. +Step 5. Extra Python Dependencies +--------------------------------- +Building from source does not ensure the installation of all necessary Python dependencies. +The following commands can be used to install the extra Python dependencies: - If an alternate build directory is used, then the environment - variable ``TVM_LIBRARY_PATH`` should be set at runtime, pointing - to the location of the compiled ``libtvm.so`` and - ``libtvm_runtime.so``. If not set, tvm will look relative to the - location of the tvm python module. Unlike ``TVM_BUILD_PATH``, - this must be an absolute path. +* Necessary dependencies: - .. code:: bash - - # Build in the "build" directory - make +.. code:: bash - # Alternate location, "build_debug" - TVM_BUILD_PATH=build_debug make + pip3 install numpy decorator attrs - # Build both "build_release" and "build_debug" - TVM_BUILD_PATH="build_debug build_release" make +* If you want to use RPC Tracker - # Use debug build - TVM_LIBRARY_PATH=~/tvm/build_debug python3 +.. code:: bash -If everything goes well, we can go to :ref:`python-package-installation` + pip3 install tornado -.. _build-with-conda: +* If you want to use auto-tuning module -Building with a Conda Environment -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. code:: bash -Conda is a very handy way to the necessary obtain dependencies needed for running TVM. -First, follow the `conda's installation guide `_ -to install miniconda or anaconda if you do not yet have conda in your system. Run the following command in a conda environment: + pip3 install tornado psutil 'xgboost>=1.1.0' cloudpickle -.. code:: bash - # Create a conda environment with the dependencies specified by the yaml - conda env create --file conda/build-environment.yaml - # Activate the created environment - conda activate tvm-build +Advanced Build Configuration +---------------------------- -The above command will install all necessary build dependencies such as cmake and LLVM. You can then run the standard build process in the last section. +Ccache +~~~~~~ +On supported platforms, the `Ccache compiler wrapper `_ may be helpful for +reducing TVM's build time, especially when building with `cutlass `_ +or `flashinfer `_. +There are several ways to enable CCache in TVM builds: -If you want to use the compiled binary outside the conda environment, -you can set LLVM to static linking mode ``set(USE_LLVM "llvm-config --link-static")``. -In this way, the resulting library won't depend on the dynamic LLVM libraries in the conda environment. + - Leave ``USE_CCACHE=AUTO`` in ``build/config.cmake``. CCache will be used if it is found. -The above instructions show how to use conda to provide the necessary build dependencies to build libtvm. -If you are already using conda as your package manager and wish to directly build and install tvm as a conda package, you can follow the instructions below: + - Ccache's Masquerade mode. This is typically enabled during the Ccache installation process. + To have TVM use Ccache in masquerade, simply specify the appropriate C/C++ compiler + paths when configuring TVM's build system. For example: + ``cmake -DCMAKE_CXX_COMPILER=/usr/lib/ccache/c++ ...``. -.. code:: bash + - Ccache as CMake's C++ compiler prefix. When configuring TVM's build system, + set the CMake variable ``CMAKE_CXX_COMPILER_LAUNCHER`` to an appropriate value. + E.g. ``cmake -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ...``. - conda build --output-folder=conda/pkg conda/recipe - # Run conda/build_cuda.sh to build with cuda enabled - conda install tvm -c ./conda/pkg Building on Windows ~~~~~~~~~~~~~~~~~~~ TVM support build via MSVC using cmake. You will need to obtain a visual studio compiler. The minimum required VS version is **Visual Studio Enterprise 2019** (NOTE: we test against GitHub Actions' `Windows 2019 Runner `_, so see that page for full details. -We recommend following :ref:`build-with-conda` to obtain necessary dependencies and +We recommend following :ref:`install-dependencies` to obtain necessary dependencies and get an activated tvm-build environment. Then you can run the following command to build .. code:: bash @@ -279,117 +278,21 @@ Currently, ROCm is supported only on linux, so all the instructions are written - You need to first install HIP runtime from ROCm. Make sure the installation system has ROCm installed in it. - Install latest stable version of LLVM (v6.0.1), and LLD, make sure ``ld.lld`` is available via command line. -.. _python-package-installation: - -Python Package Installation ---------------------------- - -TVM package -~~~~~~~~~~~ - -Depending on your development environment, you may want to use a virtual environment and package manager, such -as ``virtualenv`` or ``conda``, to manage your python packages and dependencies. - -The python package is located at `tvm/python` -There are two ways to install the package: - -Method 1 - This method is **recommended for developers** who may change the codes. - - Set the environment variable `PYTHONPATH` to tell python where to find - the library. For example, assume we cloned `tvm` on the directory - `/path/to/tvm` then we can add the following line in `~/.bashrc`. - The changes will be immediately reflected once you pull the code and rebuild the project (no need to call ``setup`` again) - - .. code:: bash - - export TVM_HOME=/path/to/tvm - export PYTHONPATH=$TVM_HOME/python:${PYTHONPATH} - - -Method 2 - Install TVM python bindings by `setup.py`: - - .. code:: bash - - # install tvm package for the current user - # NOTE: if you installed python via homebrew, --user is not needed during installaiton - # it will be automatically installed to your user directory. - # providing --user flag may trigger error during installation in such case. - export MACOSX_DEPLOYMENT_TARGET=10.9 # This is required for mac to avoid symbol conflicts with libstdc++ - cd python; python setup.py install --user; cd .. - -Python dependencies -~~~~~~~~~~~~~~~~~~~ - -Note that the ``--user`` flag is not necessary if you're installing to a managed local environment, -like ``virtualenv``. - - * Necessary dependencies: - - .. code:: bash - - pip3 install --user numpy decorator attrs - - * If you want to use ``tvmc``: the TVM command line driver. - - .. code:: bash - - pip3 install --user typing-extensions psutil scipy - - * If you want to use RPC Tracker - - .. code:: bash - - pip3 install --user tornado - - * If you want to use auto-tuning module - - .. code:: bash - - pip3 install --user tornado psutil 'xgboost>=1.1.0' cloudpickle - -Note on M1 macs, you may have trouble installing xgboost / scipy. scipy and xgboost requires some additional dependencies to be installed, -including openblas and its dependencies. Use the following commands to install scipy and xgboost with the required dependencies and -configuration. A workaround for this is to do the following commands: - - .. code:: bash - - brew install openblas gfortran - - pip install pybind11 cython pythran - - export OPENBLAS=/opt/homebrew/opt/openblas/lib/ - - pip install scipy --no-use-pep517 - - pip install 'xgboost>=1.1.0' - -Install Contrib Libraries -------------------------- - -.. toctree:: - :maxdepth: 1 - - nnpack - - .. _install-from-source-cpp-tests: Enable C++ Tests ----------------- +~~~~~~~~~~~~~~~~ We use `Google Test `_ to drive the C++ tests in TVM. The easiest way to install GTest is from source. - .. code:: bash - - git clone https://github.com/google/googletest - cd googletest - mkdir build - cd build - cmake -DBUILD_SHARED_LIBS=ON .. - make - sudo make install +.. code:: bash + git clone https://github.com/google/googletest + cd googletest + mkdir build + cd build + cmake -DBUILD_SHARED_LIBS=ON .. + make + sudo make install After installing GTest, the C++ tests can be built and started with ``./tests/scripts/task_cpp_unittest.sh`` or just built with ``make cpptest``. diff --git a/docs/install/index.rst b/docs/install/index.rst index ab2e06d0de47..6bc2da97e119 100644 --- a/docs/install/index.rst +++ b/docs/install/index.rst @@ -21,11 +21,10 @@ Installing TVM ============== .. toctree:: - :maxdepth: 2 + :maxdepth: 1 from_source docker - nnpack Visit the :ref:`install TVM from source ` page to install TVM from the source code. Installing from source gives you the maximum flexibility to configure the build effectively from the official source releases. diff --git a/docs/install/nnpack.rst b/docs/install/nnpack.rst deleted file mode 100644 index c5516235a303..000000000000 --- a/docs/install/nnpack.rst +++ /dev/null @@ -1,118 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - -NNPACK Contrib Installation -=========================== - -`NNPACK `_ is an acceleration package -for neural network computations, which can run on x86-64, ARMv7, or ARM64 architecture CPUs. -Using NNPACK, higher-level libraries like _MXNet_ can speed up -the execution on multi-core CPU computers, including laptops and mobile devices. - -.. note:: - - AS TVM already has natively tuned schedules, NNPACK is here mainly for reference and comparison purpose. - For regular use prefer native tuned TVM implementation. - -TVM supports NNPACK for forward propagation (inference only) in convolution, max-pooling, and fully-connected layers. -In this document, we give a high level overview of how to use NNPACK with TVM. - -Conditions ----------- - -The underlying implementation of NNPACK utilizes several acceleration methods, -including fft and winograd. -These algorithms work better on some special `batch size`, `kernel size`, and `stride` settings than on other, -so depending on the context, not all convolution, max-pooling, or fully-connected layers can be powered by NNPACK. -When favorable conditions for running NNPACKS are not met, - -NNPACK only supports Linux and OS X systems. Windows is not supported at present. - -Build/Install NNPACK --------------------- - -If the trained model meets some conditions of using NNPACK, -you can build TVM with NNPACK support. -Follow these simple steps: - -build NNPACK shared library with the following commands. TVM will link NNPACK dynamically. - -Note: The following NNPACK installation instructions have been tested on Ubuntu 16.04. - -Build Ninja -~~~~~~~~~~~ - -NNPACK need a recent version of Ninja. So we need to install ninja from source. - -.. code:: bash - - git clone git://github.com/ninja-build/ninja.git - cd ninja - ./configure.py --bootstrap - - -Set the environment variable PATH to tell bash where to find the ninja executable. For example, assume we cloned ninja on the home directory ~. then we can added the following line in ~/.bashrc. - - -.. code:: bash - - export PATH="${PATH}:~/ninja" - - -Build NNPACK -~~~~~~~~~~~~ - -The new CMAKE version of NNPACK download `Peach `_ and other dependencies alone - -Note: at least on OS X, running `ninja install` below will overwrite googletest libraries installed in `/usr/local/lib`. If you build googletest again to replace the nnpack copy, be sure to pass `-DBUILD_SHARED_LIBS=ON` to `cmake`. - -.. code:: bash - - git clone --recursive https://github.com/Maratyszcza/NNPACK.git - cd NNPACK - # Add PIC option in CFLAG and CXXFLAG to build NNPACK shared library - sed -i "s|gnu99|gnu99 -fPIC|g" CMakeLists.txt - sed -i "s|gnu++11|gnu++11 -fPIC|g" CMakeLists.txt - mkdir build - cd build - # Generate ninja build rule and add shared library in configuration - cmake -G Ninja -D BUILD_SHARED_LIBS=ON .. - ninja - sudo ninja install - - # Add NNPACK lib folder in your ldconfig - echo "/usr/local/lib" > /etc/ld.so.conf.d/nnpack.conf - sudo ldconfig - - -Build TVM with NNPACK support ------------------------------ - -.. code:: bash - - git clone --recursive https://github.com/apache/tvm tvm - -- Set `set(USE_NNPACK ON)` in config.cmake. -- Set `NNPACK_PATH` to the $(YOUR_NNPACK_INSTALL_PATH) - -after configuration use `make` to build TVM - - -.. code:: bash - - make From b76ebad8867e36121708cf654923b66c4f7c9ede Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 21 Aug 2024 09:04:34 -0400 Subject: [PATCH 070/202] [Codegen] Emit `tir::Let` as var assignment explicitly (#17278) Prior to this PR, the PrimExpr `tir::Let` is treated as inlining during codegen, which makes any common subexpression elimination (CSE) efforts using `tir::Let` at TIR level effectless. This PR updates codegen so that the `tir::Let` will have an explicit var assignment and thus can effectively reflect the CSE efforts. --- python/tvm/relax/frontend/nn/op.py | 6 +++--- src/target/source/codegen_c.cc | 21 ++++++++++++++++++++- tests/python/relax/test_frontend_nn_op.py | 6 +++--- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 17a40a8cce57..04c030bea6fa 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2544,7 +2544,7 @@ def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j): @T.prim_func(private=True) def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): - batch, vocab_size = T.int64(), T.int64() + batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype) top_p = T.match_buffer(B, (batch, 1), prob_dtype) top_k = T.match_buffer(C, (batch, 1), index_dtype) @@ -2564,8 +2564,8 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): def _get_index_from_sorted( A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle ): - batch, vocab_size = T.int64(), T.int64() - out_batch = T.int64() + batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) + out_batch = T.int64(is_size_var=True) cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype) indices = T.match_buffer(B, (batch, vocab_size), index_dtype) renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 03c3e3af66d5..9f68cd8d669a 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -887,8 +887,27 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) let_binding_[op->var] = op; } std::string value = PrintExpr(op->value); - var_idmap_[op->var.get()] = value; + if (print_ssa_form_) { + ICHECK(!var_idmap_.count(op->var.get())); + var_idmap_[op->var.get()] = value; + } else { + PrintIndent(); + if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) { + PrintType(handle_data_type_.at(op->var.get()), this->stream); + this->stream << "* " << AllocVarID(op->var.get()) << " = ("; + PrintType(handle_data_type_.at(op->var.get()), this->stream); + this->stream << "*)" << value << ";\n"; + } else { + PrintType(op->var.dtype(), this->stream); + this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n"; + } + } os << PrintExpr(op->body); + // Pop the defined var from var_idmap when exiting its scope. + // We do this because it is hard to completely avoid a same LetNode appearing + // at different places. + bool removed = var_idmap_.erase(op->var.get()); + ICHECK(removed); } void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 6c3269195498..40624790cb5a 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -947,11 +947,11 @@ def foo( class Expected: @T.prim_func(private=True) def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle): - batch, vocab_size = T.int64(), T.int64() + batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) indices = T.match_buffer(B, (batch, vocab_size), "int64") renorm_prob = T.match_buffer(C, (batch, 1)) - out_batch = T.int64() + out_batch = T.int64(is_size_var=True) usample = T.match_buffer(D, (out_batch, 1)) sample_indices = T.match_buffer(E, (out_batch, 1), "int64") output_index = T.match_buffer(F, (out_batch, 1), "int64") @@ -970,7 +970,7 @@ def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: @T.prim_func(private=True) def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): - batch, vocab_size = T.int64(), T.int64() + batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) top_p = T.match_buffer(B, (batch, 1)) top_k = T.match_buffer(C, (batch, 1), "int64") From 32063b0dfcb8ffcec6b7b4f99bc51adb178f1394 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 22 Aug 2024 22:24:23 +0800 Subject: [PATCH 071/202] [Doc] Quick Start (#17289) This PR introduces a new quick start tutorial to the documentation. --- docs/.gitignore | 1 - docs/conf.py | 6 + docs/get_started/tutorials/README.txt | 2 + docs/get_started/tutorials/quick_start.py | 193 ++++++++++++++++++++++ docs/index.rst | 1 + tests/scripts/task_python_docs.sh | 2 + 6 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 docs/get_started/tutorials/README.txt create mode 100644 docs/get_started/tutorials/quick_start.py diff --git a/docs/.gitignore b/docs/.gitignore index 84b247d3699c..041cf3588799 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,3 +1,2 @@ doxygen modules -tutorials diff --git a/docs/conf.py b/docs/conf.py index be1ba11aa091..c3472c15de91 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -408,6 +408,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): from sphinx_gallery.sorting import ExplicitOrder examples_dirs = [ + # legacy tutorial structure under gallery folder tvm_path.joinpath("gallery", "tutorial"), tvm_path.joinpath("gallery", "how_to", "compile_models"), tvm_path.joinpath("gallery", "how_to", "deploy_models"), @@ -419,9 +420,12 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): tvm_path.joinpath("gallery", "how_to", "work_with_microtvm"), tvm_path.joinpath("gallery", "how_to", "extend_tvm"), tvm_path.joinpath("vta", "tutorials"), + # New tutorial structure under docs folder + tvm_path.joinpath("docs", "get_started", "tutorials"), ] gallery_dirs = [ + # legacy tutorial structure under gallery folder "tutorial", "how_to/compile_models", "how_to/deploy_models", @@ -433,6 +437,8 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): "how_to/work_with_microtvm", "how_to/extend_tvm", "topic/vta/tutorials", + # New tutorial structure under docs folder + "get_started/tutorials/", ] diff --git a/docs/get_started/tutorials/README.txt b/docs/get_started/tutorials/README.txt new file mode 100644 index 000000000000..62e2c7b770fb --- /dev/null +++ b/docs/get_started/tutorials/README.txt @@ -0,0 +1,2 @@ +Get Started +----------- diff --git a/docs/get_started/tutorials/quick_start.py b/docs/get_started/tutorials/quick_start.py new file mode 100644 index 000000000000..a4edf0b7c4fe --- /dev/null +++ b/docs/get_started/tutorials/quick_start.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _quick_start: + +Quick Start +=========== + +This tutorial is for people who are new to Apache TVM. Taking an simple example +to show how to use Apache TVM to compile a simple neural network. + +.. contents:: Table of Contents + :local: + :depth: 2 + +""" + +################################################################################ +# Overview +# -------- +# Apache TVM is a machine learning compilation framework, following the principle of +# **Python-first development** and **universal deployment**. It takes in pre-trained +# machine learning models, compiles and generates deployable modules that can be embedded +# and run everywhere. +# Apache TVM also enables customizing optimization processes to introduce new optimizations, +# libraries, codegen and more. +# +# Apache TVM can help to: +# +# - **Optimize** performance of ML workloads, composing libraries and codegen. +# - **Deploy** ML workloads to a diverse set of new environments, including new runtime and new +# hardware. +# - **Continuously improve and customize** ML deployment pipeline in Python by quickly customizing +# library dispatching, bringing in customized operators and code generation. + +################################################################################ +# Overall Flow +# ------------ +# Then we will show the overall flow of using Apache TVM to compile a neural network model, +# showing how to optimize, deploy and run the model. +# The overall flow is illustrated as the figure: +# +# .. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg +# :align: center +# :width: 80% +# +# The overall flow consists of the following steps: +# +# - **Construct or Import a Model**: Construct a neural network model or import a pre-trained +# model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains +# all the information needed for compilation, including high-level Relax functions for +# computational graph, and low-level TensorIR functions for tensor program. +# - **Perform Composable Optimizations**: Perform a series of optimization transformations, +# such as graph optimizations, tensor program optimizations, and library dispatching. +# - **Build and Universal Deployment**: Build the optimized model to a deployable module to the +# universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators. + +################################################################################ +# Construct or Import a Model +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we get started, let's construct a neural network model first. +# In this tutorial, to make things simple, we will defined a two-layer MLP networks +# directly in this script with TVM Relax frontend, which is a similar API to PyTorch. +# + +import tvm +from tvm import relax +from tvm.relax.frontend import nn + + +class MLPModel(nn.Module): + def __init__(self): + super(MLPModel, self).__init__() + self.fc1 = nn.Linear(784, 256) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(256, 10) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + return x + + +################################################################################ +# Then we can export the model to TVM IRModule, which is the central intermediate representation +# in TVM. + +mod, param_spec = MLPModel().export_tvm( + spec={"forward": {"x": nn.spec.Tensor((1, 784), "float32")}} +) +mod.show() + +################################################################################ +# Perform Optimization Transformations +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Apache TVM leverage ``pipeline`` to transform and optimize program. +# The pipeline encapsulates a collection of transformation that gets two goals (at the same level): +# +# - **Model optimizations**: such as operator fusion, layout rewrites. +# - **Tensor program optimization**: Map the operators to low-level implementations +# (both library or codegen) +# +# .. note:: +# The twos are goals but not the stages of the pipeline. The two optimizations are performed +# **at the same level**, or separately in two stages. +# +# .. note:: +# In this tutorial we only demonstrate the overall flow, by leverage ``zero`` optimization +# pipeline, instead of optimizing for any specific target. + +mod = relax.get_pipeline("zero")(mod) + + +################################################################################ +# Build and Universal Deployment +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the optimization, we can build the model to a deployable module and run it on +# different devices. + + +import numpy as np + +target = tvm.target.Target("llvm") +ex = relax.build(mod, target) +device = tvm.cpu() +vm = relax.VirtualMachine(ex, device) +data = np.random.rand(1, 784).astype("float32") +tvm_data = tvm.nd.array(data, device=device) +params = [np.random.rand(*param.shape).astype("float32") for _, param in param_spec] +params = [tvm.nd.array(param, device=device) for param in params] +print(vm["forward"](tvm_data, *params).numpy()) + +################################################################################ +# Our goal is to bring machine learning to the application with any language of interest, +# with the minimum runtime support. +# +# - Each function in IRModule becomes a runnable function in the runtime. For example in LLM +# cases, we can call ``prefill`` and ``decode`` functions directly. +# +# .. code-block:: Python +# +# prefill_logits = vm["prefill"](inputs, weight, kv_cache) +# decoded_logits = vm["decode"](inputs, weight, kv_cache) +# +# - TVM runtime comes with native data structures, such as NDArray, can also have zero +# copy exchange with existing ecosystem (DLPack exchange with PyTorch) +# +# .. code-block:: Python +# +# # Convert PyTorch tensor to TVM NDArray +# x_tvm = tvm.nd.from_dlpack(x_torch.to_dlpack()) +# # Convert TVM NDArray to PyTorch tensor +# x_torch = torch.from_dlpack(x_tvm.to_dlpack()) +# +# - TVM runtime works in non-python environments, so it works on settings such as mobile +# +# .. code-block:: C++ +# +# // C++ snippet +# runtime::Module vm = ex.GetFunction("load_executable")(); +# vm.GetFunction("init")(...); +# NDArray out = vm.GetFunction("prefill")(data, weight, kv_cache); +# +# .. code-block:: Java +# +# // Java snippet +# Module vm = ex.getFunction("load_executable").invoke(); +# vm.getFunction("init").pushArg(...).invoke; +# NDArray out = vm.getFunction("prefill").pushArg(data).pushArg(weight).pushArg(kv_cache).invoke(); +# + +################################################################################ +# Read next +# --------- +# This tutorial demonstrates the overall flow of using Apache TVM to compile a neural network model. +# For more advanced or specific topics, please refer to the following tutorials +# diff --git a/docs/index.rst b/docs/index.rst index 95b1937671ea..7f13101f741e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,6 +33,7 @@ driving its costs down. :caption: Getting Started install/index + get_started/tutorials/quick_start contribute/index .. toctree:: diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 9690c330c0df..2a213ddd1843 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -90,6 +90,8 @@ IGNORED_WARNINGS=( 'absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type.' 'absl:Found untraced functions such as _jit_compiled_convolution_op' 'You are using pip version' + # Tutorial READMEs can be ignored, but other docs should be included + "tutorials/README.rst: WARNING: document isn't included in any toctree" ) JOINED_WARNINGS=$(join_by '|' "${IGNORED_WARNINGS[@]}") From ed9aa56b373c60acef151d4defac44e3c2360a0a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 22 Aug 2024 11:26:27 -0500 Subject: [PATCH 072/202] [Relax][Analysis] Handle recursive functions in CollectVarUsage (#17224) * [Relax][Analysis] Handle recursive functions in CollectVarUsage Prior to this commit, the `relax::analysis::CollectVarUsage` utility treated a local function definition as in-scope after visiting the body of the local function. As a result, recursive calls from a local function were incorrectly identified as calls to an undefined variable. This commit updates the `CollectVarUsage` to treat a local function definition as in-scope when inspecting the function body. This change is similar to the change made for structural equality in https://github.com/apache/tvm/pull/16756. * lint fixes --- src/relax/analysis/udchain.cc | 21 ++++- .../test_transform_dead_code_elimination.py | 81 +++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index d7ab4f1031b4..65e15a4161dd 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -55,6 +55,7 @@ class UDChain : relax::ExprVisitor { private: Map bound_values; + std::unordered_set forward_declarations; std::unordered_map> usage_map; support::OrderedSet outputs; @@ -71,9 +72,20 @@ class UDChain : relax::ExprVisitor { cur_user_ = cache; } + void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override { + // A local Relax function may be recursively defined. References to + // `binding->var` that appear within `func` are valid. + DefineVar(binding->var); + forward_declarations.insert(binding->var); + ExprVisitor::VisitBinding_(binding, func); + } + void VisitVarDef(const Var& var) override { - CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition"; - usage_map[var] = {}; + if (forward_declarations.count(var)) { + forward_declarations.erase(var); + } else { + DefineVar(var); + } } void VisitExpr_(const VarNode* op) override { auto var = GetRef(op); @@ -89,6 +101,11 @@ class UDChain : relax::ExprVisitor { cur_user_ = nullptr; ExprVisitor::VisitExpr_(op); } + + void DefineVar(const Var& var) { + CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition"; + usage_map[var] = {}; + } }; std::pair>, runtime::Array> FunctionUseDef( diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 142faf51607b..6546d09777b0 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -658,5 +658,86 @@ def subsubroutine(A: R.Tensor) -> R.Tensor: tvm.ir.assert_structural_equal(Expected, After) +def test_recursively_defined_lambda(): + """DCE may be applied to recursively-defined functions + + While most expressions may only contain references to + previously-defined variables, local Relax function definitions may + contain references to themselves. + + This is a regression test. In previous implementations, the + recursive use of `while_loop` resulted in an error, as + `while_loop` was not considered in-scope by the `CollectVarUsage` + utility until after the body of `while_loop` had been visited. + + """ + + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + @R.function + def while_loop( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond = R.call_pure_packed( + "test.vm.less", i, R.const(10), sinfo_args=R.Tensor((), dtype="bool") + ) + c = R.const(1, dtype="int32") + if cond: + new_i = R.add(i, c) + new_s = R.add(s, x) + r = while_loop(new_i, new_s) + else: + r = s + return r + + gv = while_loop(R.const(0), x) + return gv + + Expected = Before + + verify(Before, Expected) + + +def test_recursively_defined_closure(): + """DCE may be applied to recursively-defined closures + + This test is identical to `test_recursively_defined_lambda`, + except that the threshold for recursion is defined in an enclosed + variable outside of the recursive function. + + """ + + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + threshold = R.const(10) + + @R.function + def while_loop( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond = R.call_pure_packed( + "test.vm.less", i, threshold, sinfo_args=R.Tensor((), dtype="bool") + ) + c = R.const(1, dtype="int32") + if cond: + new_i = R.add(i, c) + new_s = R.add(s, x) + r = while_loop(new_i, new_s) + else: + r = s + return r + + gv = while_loop(R.const(0), x) + return gv + + Expected = Before + + verify(Before, Expected) + + if __name__ == "__main__": tvm.testing.main() From 20289e8502dd27c91f3945418c864ad7233aec89 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 22 Aug 2024 12:12:56 -0500 Subject: [PATCH 073/202] [Cleanup] Remove `using namespace tvm::runtime` from headers (#17246) Prior to this commit, various header files had `using namespace tvm::runtime`, which imports all names from `tvm::runtime` into the current namespace. These imports can cause compilation errors depending on the order of `#include` statements. For example, the `#include ` file uses the unqualified name `Bool` to refer to `::tvm::Bool`, a subclass of `PrimExpr`. If a different header file specifies `using namespace tvm::runtime` within the `tvm::relay` namespace, then the unqualified name `Bool` ambiguously refers to either `::tvm::Bool` or `::tvm::runtime::Bool`. In MSVC, this can cause even further compilation errors. By default, MSVC does not follow the C++ standard for name resolution in templates. The standard requires that any names in a template that do not depend on template parameters be resolved when the template is declared. However, MSVC instead resolves these names when the template is instantiated. As a result, the same `using namespace tvm::runtime` may cause a compilation error if it occurs after the template's declaration, but before the template's usage. (TVM provides the `/permissive-` flag to MSVC builds specifically to disable MSVC's non-standard name resolution, so this only impacts downstream forks that disable this flag. See https://github.com/apache/tvm/pull/16343 for more details.) This commit removes `using namespace tvm::runtime`, replacing them with explicit `using tvm::runtime::SOME_SPECIFIC_SYMBOL` where necessary. This resolves both the include-order dependency for standards-compliant compilers, and the compilation errors for MSVC's default build. --- src/contrib/msc/core/ir/graph_builder.h | 3 ++- src/relay/backend/vm/compiler.h | 3 ++- src/relay/parser/parser.cc | 2 ++ src/relay/parser/token.h | 2 -- src/relay/parser/tokenizer.h | 2 -- src/runtime/contrib/cblas/gemm_common.h | 5 ++++- src/runtime/contrib/json/json_node.h | 1 - src/runtime/contrib/nnpack/nnpack_utils.h | 1 - src/runtime/contrib/verilator/verilator_runtime.h | 1 - 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 4b042c5617e4..d514a793475d 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -51,7 +51,8 @@ namespace msc { using Expr = tvm::RelayExpr; using RelaxExprVisitor = tvm::relax::ExprVisitor; using RelayExprVisitor = tvm::relay::ExprVisitor; -using namespace tvm::runtime; + +using tvm::runtime::NDArray; /*! * \brief Config for building MSCGraph. diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index acb4d2d1d258..d22fb3d4d5ca 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -51,7 +51,8 @@ namespace tvm { namespace relay { namespace vm { -using namespace tvm::runtime; +using tvm::runtime::ModulePropertyMask; +using tvm::runtime::NDArray; using namespace tvm::runtime::vm; using namespace relay::transform; diff --git a/src/relay/parser/parser.cc b/src/relay/parser/parser.cc index b519a1778ce0..233455bf89ba 100644 --- a/src/relay/parser/parser.cc +++ b/src/relay/parser/parser.cc @@ -48,6 +48,8 @@ namespace relay { /*! \brief The meta table maps from type key to a sequence of objects. */ using MetaTable = Map>; +using tvm::runtime::NDArray; +using tvm::runtime::String2DLDataType; using tvm::transform::CreateModulePass; using tvm::transform::PassContext; diff --git a/src/relay/parser/token.h b/src/relay/parser/token.h index 7b11e701cf6e..13875cb09391 100644 --- a/src/relay/parser/token.h +++ b/src/relay/parser/token.h @@ -36,8 +36,6 @@ namespace tvm { namespace relay { -using namespace runtime; - enum class TokenType { kCommentStart, kCommentEnd, diff --git a/src/relay/parser/tokenizer.h b/src/relay/parser/tokenizer.h index 04dcd3263e99..2b7ad4e5593e 100644 --- a/src/relay/parser/tokenizer.h +++ b/src/relay/parser/tokenizer.h @@ -41,8 +41,6 @@ namespace tvm { namespace relay { -using namespace runtime; - // trim from start (in place) static inline void ltrim(std::string& s) { // NOLINT(*) s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); })); diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index af073da9ba1a..91341976bd02 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -34,7 +34,10 @@ namespace tvm { namespace contrib { -using namespace runtime; +using runtime::TVMArgs; +using runtime::TVMRetValue; +using runtime::TypeMatch; + inline int ColumnStride(const DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides diff --git a/src/runtime/contrib/json/json_node.h b/src/runtime/contrib/json/json_node.h index bafe6cfbec18..dd16c606815a 100644 --- a/src/runtime/contrib/json/json_node.h +++ b/src/runtime/contrib/json/json_node.h @@ -42,7 +42,6 @@ namespace tvm { namespace runtime { namespace json { -using namespace tvm::runtime; using JSONGraphAttrs = std::unordered_map; /*! diff --git a/src/runtime/contrib/nnpack/nnpack_utils.h b/src/runtime/contrib/nnpack/nnpack_utils.h index 4396ea0bcde6..ed0312dac476 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.h +++ b/src/runtime/contrib/nnpack/nnpack_utils.h @@ -30,7 +30,6 @@ namespace tvm { namespace contrib { -using namespace runtime; struct NNPackThreadLocalEntry { pthreadpool_t threadpool{nullptr}; diff --git a/src/runtime/contrib/verilator/verilator_runtime.h b/src/runtime/contrib/verilator/verilator_runtime.h index 9ef17d7481ab..14bf0bcdfc9b 100644 --- a/src/runtime/contrib/verilator/verilator_runtime.h +++ b/src/runtime/contrib/verilator/verilator_runtime.h @@ -43,7 +43,6 @@ namespace tvm { namespace runtime { namespace contrib { -using namespace tvm::runtime; using namespace tvm::runtime::contrib; using namespace tvm::runtime::json; From 0f037a6d9957108decceaf0c91bd84667a077aad Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 22 Aug 2024 12:13:16 -0500 Subject: [PATCH 074/202] [FFI][Runtime] Use TVMValue::v_int64 to represent boolean values (#17240) * [FFI][Runtime] Use TVMValue::v_int64 to represent boolean values This is a follow-up to https://github.com/apache/tvm/pull/16183, which added handling of boolean values in the TVM FFI. The initial implementation added both a new type code (`kTVMArgBool`) and a new `TVMValue::v_bool` variant. This commit removes the `TVMValue::v_bool` variant, since the `kTVMArgBool` type code is sufficient to handle boolean arguments. Removing the `TVMValue::v_bool` variant also makes all `TVMValue` variants be 64-bit (assuming a 64-bit CPU). This can simplify debugging in some cases, since it prevents partial values from inactive variants from being present in memory. * Update MakePackedAPI, less special handling required for boolean --- include/tvm/runtime/c_runtime_api.h | 1 - include/tvm/runtime/packed_func.h | 10 +++++----- python/tvm/_ffi/_cython/packed_func.pxi | 4 ++-- rust/tvm-sys/src/packed_func.rs | 4 ++-- src/runtime/crt/common/crt_runtime_api.c | 4 +--- src/runtime/minrpc/rpc_reference.h | 4 ++-- src/target/llvm/codegen_cpu.cc | 2 +- src/tir/transforms/ir_utils.h | 3 +-- src/tir/transforms/make_packed_api.cc | 20 +++++++------------ .../codegen/test_target_codegen_llvm.py | 16 +++++++++++++++ .../test_tir_transform_make_packed_api.py | 12 ++--------- 11 files changed, 39 insertions(+), 41 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index b4c653a0a59e..d26c95e4f53c 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -209,7 +209,6 @@ typedef DLTensor* TVMArrayHandle; */ typedef union { int64_t v_int64; - bool v_bool; double v_float64; void* v_handle; const char* v_str; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 91e53055b708..7c1b08e49002 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -669,7 +669,7 @@ class TVMPODValue_ { // conversions. This is publicly exposed, as it can be useful in // specializations of PackedFuncValueConverter. if (type_code_ == kTVMArgBool) { - return value_.v_bool; + return static_cast(value_.v_int64); } else { return std::nullopt; } @@ -1041,7 +1041,7 @@ class TVMRetValue : public TVMPODValue_CRTP_ { TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { this->SwitchToPOD(kTVMArgBool); - value_.v_bool = value; + value_.v_int64 = value; return *this; } TVMRetValue& operator=(std::string value) { @@ -1831,7 +1831,7 @@ class TVMArgsSetter { type_codes_[i] = kDLInt; } TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const { - values_[i].v_bool = value; + values_[i].v_int64 = value; type_codes_[i] = kTVMArgBool; } TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { @@ -2142,7 +2142,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { std::is_base_of_v) { if (std::is_base_of_v || ptr->IsInstance()) { - values_[i].v_bool = static_cast(ptr)->value; + values_[i].v_int64 = static_cast(ptr)->value; type_codes_[i] = kTVMArgBool; return; } @@ -2327,7 +2327,7 @@ inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { if constexpr (std::is_base_of_v) { if (type_code_ == kTVMArgBool) { - return Bool(value_.v_bool); + return Bool(value_.v_int64); } } diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 7977f37d0be5..6e062ab5f199 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -121,7 +121,7 @@ cdef inline int make_arg(object arg, elif isinstance(arg, bool): # A python `bool` is a subclass of `int`, so this check # must occur before `Integral`. - value[0].v_bool = arg + value[0].v_int64 = arg tcode[0] = kTVMArgBool elif isinstance(arg, Integral): value[0].v_int64 = arg @@ -215,7 +215,7 @@ cdef inline object make_ret(TVMValue value, int tcode): elif tcode == kTVMNullptr: return None elif tcode == kTVMArgBool: - return value.v_bool + return bool(value.v_int64) elif tcode == kInt: return value.v_int64 elif tcode == kFloat: diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 2c1f7db6adb0..3d78ce52d621 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -96,7 +96,7 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool), + TVMArgTypeCode_kTVMArgBool => Bool($value.v_int64 != 0), TVMArgTypeCode_kTVMNullptr => Null, TVMArgTypeCode_kTVMDataType => DataType($value.v_type), TVMArgTypeCode_kDLDevice => Device($value.v_device), @@ -119,7 +119,7 @@ macro_rules! TVMPODValue { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool), + Bool(val) => (TVMValue { v_int64: *val as i64 }, TVMArgTypeCode_kTVMArgBool), Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 04d36ad8bcab..2df37205b89c 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -362,10 +362,8 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r return kTvmErrorFunctionCallWrongArgType; } - if (type_codes[2] == kDLInt) { + if (type_codes[2] == kDLInt || type_codes[2] == kTVMArgBool) { query_imports = args[2].v_int64 != 0; - } else if (type_codes[2] == kTVMArgBool) { - query_imports = args[2].v_bool; } else { TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer"); return kTvmErrorFunctionCallWrongArgType; diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index 485ebdb449da..13c1fa4b38d3 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -326,7 +326,7 @@ struct RPCReference { break; } case kTVMArgBool: { - channel->template Write(value.v_bool); + channel->template Write(value.v_int64); break; } case kTVMDataType: { @@ -437,7 +437,7 @@ struct RPCReference { break; } case kTVMArgBool: { - channel->template Read(&(value.v_bool)); + channel->template Read(&(value.v_int64)); break; } case kTVMDataType: { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 21899a12c4b0..b9e18bc4f8d2 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -1379,7 +1379,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr); if (op->dtype == DataType::Bool()) { - struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value); + struct_value = CreateCast(DataType::Int(64), op->dtype, struct_value); } return struct_value; diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 2948773321dd..05345aab8628 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -155,8 +155,7 @@ inline DataType APIType(DataType t) { ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; if (t.is_handle()) return t; ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; - if (t.is_bool()) return DataType::Bool(); - if (t.is_uint() || t.is_int()) return DataType::Int(64); + if (t.is_bool() || t.is_uint() || t.is_int()) return DataType::Int(64); ICHECK(t.is_float()); return DataType::Float(64); } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 9f2f1295fece..cf388630fcf6 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -81,7 +81,11 @@ class ReturnRewriter : public StmtMutator { // convert val's data type to FFI data type, return type code DataType dtype = val.dtype(); - if (dtype.is_int() || dtype.is_uint()) { + if (dtype.is_bool()) { + info.tcode = kTVMArgBool; + info.expr = Cast(DataType::Int(64), val); + + } else if (dtype.is_int() || dtype.is_uint()) { info.tcode = kTVMArgInt; info.expr = Cast(DataType::Int(64), val); } else if (dtype.is_float()) { @@ -340,12 +344,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { seq_init.emplace_back( AssertStmt(tcode == kTVMArgBool || tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); - arg_value = Call(t, builtin::if_then_else(), - { - tcode == kTVMArgBool, - f_arg_value(DataType::Bool(), i), - cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)), - }); + arg_value = cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; @@ -353,12 +352,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { seq_init.emplace_back( AssertStmt(tcode == kDLInt || tcode == kTVMArgBool, tvm::tir::StringImm(msg.str()), nop)); - arg_value = Call(t, builtin::if_then_else(), - { - tcode == kTVMArgInt, - f_arg_value(t, i), - cast(t, f_arg_value(DataType::Bool(), i)), - }); + arg_value = f_arg_value(t, i); } else { ICHECK(t.is_float()); std::ostringstream msg; diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index d9a6fd6e62d1..e8036467ffb6 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1179,5 +1179,21 @@ def func(arg: T.bool) -> T.int32: assert output == 20 +def test_bool_return_value(): + """Booleans may be returned from a PrimFunc""" + + @T.prim_func + def func(value: T.int32) -> T.bool: + T.func_attr({"target": T.target("llvm")}) + return value < 10 + + built = tvm.build(func) + assert isinstance(built(0), bool) + assert built(0) + + assert isinstance(built(15), bool) + assert not built(15) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index 0b43db56f300..f783ab2fcef1 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -444,11 +444,7 @@ def main( arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) arg_code: T.int32 = arg_type_ids_1[0] assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be int" - arg: T.int32 = T.if_then_else( - arg_code == 0, - T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")), - T.Cast("int32", T.tvm_struct_get(args, 0, 12, "bool")), - ) + arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")) with T.attr(0, "compute_scope", "main_compute_"): out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) @@ -510,11 +506,7 @@ def main( arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) arg_code: T.int32 = arg_type_ids_1[0] assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be boolean" - arg: T.bool = T.if_then_else( - arg_code == 15, - T.tvm_struct_get(args, 0, 12, "bool"), - T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")), - ) + arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")) with T.attr(0, "compute_scope", "main_compute_"): out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) From 8db545dddd09e1cb892d3efc8f5859acaf52482a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 22 Aug 2024 13:33:04 -0400 Subject: [PATCH 075/202] [ROCm] hipBLAS integration (#17290) This commit integrates hipBLAS into TVM. The minimum ROCm version requirement is 6.0. Co-authored-by: Lesheng Jin --- CMakeLists.txt | 1 + cmake/modules/LibInfo.cmake | 1 + cmake/modules/ROCM.cmake | 12 + cmake/utils/FindROCM.cmake | 4 + python/tvm/contrib/hipblas.py | 86 ++++ python/tvm/relax/backend/contrib/hipblas.py | 180 +++++++ python/tvm/testing/utils.py | 3 + src/relax/backend/contrib/hipblas/codegen.cc | 110 +++++ src/runtime/contrib/hipblas/hipblas.cc | 456 ++++++++++++++++++ .../contrib/hipblas/hipblas_json_runtime.cc | 153 ++++++ src/runtime/contrib/hipblas/hipblas_utils.cc | 78 +++ src/runtime/contrib/hipblas/hipblas_utils.h | 155 ++++++ src/support/libinfo.cc | 1 + tests/python/contrib/test_hipblas.py | 109 +++++ tests/python/relax/test_codegen_hipblas.py | 165 +++++++ 15 files changed, 1514 insertions(+) create mode 100644 python/tvm/contrib/hipblas.py create mode 100644 python/tvm/relax/backend/contrib/hipblas.py create mode 100644 src/relax/backend/contrib/hipblas/codegen.cc create mode 100644 src/runtime/contrib/hipblas/hipblas.cc create mode 100644 src/runtime/contrib/hipblas/hipblas_json_runtime.cc create mode 100644 src/runtime/contrib/hipblas/hipblas_utils.cc create mode 100644 src/runtime/contrib/hipblas/hipblas_utils.h create mode 100644 tests/python/contrib/test_hipblas.py create mode 100644 tests/python/relax/test_codegen_hipblas.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 7fba5355f077..aa2a385683d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -107,6 +107,7 @@ tvm_option(USE_THRUST "Build with Thrust" OFF) tvm_option(USE_CURAND "Build with cuRAND" OFF) tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF) tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) +tvm_option(USE_HIPBLAS "Build with ROCM:HIPBLAS" OFF) tvm_option(USE_SORT "Build with sort support" ON) tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_LIBTORCH "Build with libtorch support" OFF) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index c4637a0c17f7..da9bc3e1c9d3 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -116,6 +116,7 @@ function(add_lib_info src_file) TVM_INFO_TVM_DEBUG_WITH_ABI_CHANGE="${TVM_DEBUG_WITH_ABI_CHANGE}" TVM_INFO_TVM_LOG_BEFORE_THROW="${TVM_LOG_BEFORE_THROW}" TVM_INFO_USE_ROCBLAS="${USE_ROCBLAS}" + TVM_INFO_USE_HIPBLAS="${USE_HIPBLAS}" TVM_INFO_USE_ROCM="${USE_ROCM}" TVM_INFO_USE_RCCL="${USE_RCCL}" TVM_INFO_USE_RPC="${USE_RPC}" diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake index 02c4c739934a..4d0f76d6871f 100644 --- a/cmake/modules/ROCM.cmake +++ b/cmake/modules/ROCM.cmake @@ -53,6 +53,18 @@ if(USE_ROCM) list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY}) endif(USE_ROCBLAS) + if(USE_HIPBLAS) + message(STATUS "Build with HIPBLAS support") + tvm_file_glob(GLOB HIPBLAS_CONTRIB_SRC src/relax/backend/contrib/hipblas/*.cc) + list(APPEND COMPILER_SRCS ${HIPBLAS_CONTRIB_SRC}) + tvm_file_glob(GLOB HIPBLAS_CONTRIB_SRCS src/runtime/contrib/hipblas/*.cc) + list(APPEND RUNTIME_SRCS ${HIPBLAS_CONTRIB_SRCS}) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPBLAS_LIBRARY}) + if(NOT ROCM_HIPBLASLT_LIBRARY STREQUAL "ROCM_HIPBLASLT_LIBRARY-NOTFOUND") + list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPBLASLT_LIBRARY}) + endif() + endif(USE_HIPBLAS) + if(USE_THRUST) message(STATUS "Build with rocThrust support") # We need to override CXX to hipcc. This is required by rocthrust diff --git a/cmake/utils/FindROCM.cmake b/cmake/utils/FindROCM.cmake index 4d895ff89d13..6f54c179ee76 100644 --- a/cmake/utils/FindROCM.cmake +++ b/cmake/utils/FindROCM.cmake @@ -55,6 +55,8 @@ macro(find_rocm use_rocm) endif() find_library(ROCM_MIOPEN_LIBRARY MIOpen ${__rocm_sdk}/lib) find_library(ROCM_ROCBLAS_LIBRARY rocblas ${__rocm_sdk}/lib) + find_library(ROCM_HIPBLAS_LIBRARY hipblas ${__rocm_sdk}/lib) + find_library(ROCM_HIPBLASLT_LIBRARY hipblaslt ${__rocm_sdk}/lib) find_library(ROCM_HSA_LIBRARY hsa-runtime64 ${__rocm_sdk}/lib) if(ROCM_HIPHCC_LIBRARY) @@ -66,5 +68,7 @@ macro(find_rocm use_rocm) message(STATUS "Found ROCM_HIPHCC_LIBRARY=" ${ROCM_HIPHCC_LIBRARY}) message(STATUS "Found ROCM_MIOPEN_LIBRARY=" ${ROCM_MIOPEN_LIBRARY}) message(STATUS "Found ROCM_ROCBLAS_LIBRARY=" ${ROCM_ROCBLAS_LIBRARY}) + message(STATUS "Found ROCM_HIPBLAS_LIBRARY=" ${ROCM_HIPBLAS_LIBRARY}) + message(STATUS "Found ROCM_HIPBLASLT_LIBRARY=" ${ROCM_HIPBLASLT_LIBRARY}) endif(ROCM_FOUND) endmacro(find_rocm) diff --git a/python/tvm/contrib/hipblas.py b/python/tvm/contrib/hipblas.py new file mode 100644 index 000000000000..f1e46a2caab1 --- /dev/null +++ b/python/tvm/contrib/hipblas.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""External function interface to hipBLAS libraries.""" +import tvm +from tvm import te + + +def matmul(lhs, rhs, transa=False, transb=False, dtype=None): + """Create an extern op that compute matrix mult of A and rhs with cuBLAS + + Parameters + ---------- + lhs : Tensor + The left matrix operand + rhs : Tensor + The right matrix operand + transa : bool + Whether transpose lhs + transb : bool + Whether transpose rhs + + Returns + ------- + C : Tensor + The result tensor. + """ + n = lhs.shape[1] if transa else lhs.shape[0] + m = rhs.shape[0] if transb else rhs.shape[1] + dtype = dtype if dtype is not None else lhs.dtype + return te.extern( + (n, m), + [lhs, rhs], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.hipblas.matmul", ins[0], ins[1], outs[0], transa, transb + ), + dtype=dtype, + name="matmul_hipblas", + ) + + +def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None): + """Create an extern op that compute batch matrix mult of A and rhs with cuBLAS + + Parameters + ---------- + lhs : Tensor + The left matrix operand + rhs : Tensor + The right matrix operand + transa : bool + Whether transpose lhs + transb : bool + Whether transpose rhs + + Returns + ------- + C : Tensor + The result tensor. + """ + b = lhs.shape[0] + n = lhs.shape[2] if transa else lhs.shape[1] + m = rhs.shape[1] if transb else rhs.shape[2] + dtype = dtype if dtype is not None else lhs.dtype + return te.extern( + (b, n, m), + [lhs, rhs], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.hipblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb + ), + dtype=dtype, + name="batch_matmul_hipblas", + ) diff --git a/python/tvm/relax/backend/contrib/hipblas.py b/python/tvm/relax/backend/contrib/hipblas.py new file mode 100644 index 000000000000..c0accc1473e1 --- /dev/null +++ b/python/tvm/relax/backend/contrib/hipblas.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern table for hipblas backend""" +import operator +from functools import reduce + +import tvm +from tvm.relax import transform +from tvm.relax.transform import PatternCheckContext + +from ..pattern_registry import get_patterns_with_prefix, register_patterns +from ..patterns import make_matmul_pattern +from ..utils import has_leaking_intermediate_variables + + +def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): # pylint: disable=unused-argument + """Check if dtypes in the given workload are supported by hipblas BYOC.""" + if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8": + # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8' + # return out_dtype != "e5m2_float8" + return False + return (lhs_dtype == "float16" and rhs_dtype == "float16") or ( + lhs_dtype == "int8" and rhs_dtype == "int8" + ) + + +def _check_matmul(context: PatternCheckContext) -> bool: + if has_leaking_intermediate_variables(context): + return False + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] + matmul_call = context.annotated_expr["root"] + + lhs_dtype = lhs.struct_info.dtype + rhs_dtype = rhs.struct_info.dtype + out_dtype = matmul_call.struct_info.dtype + if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): + return False + + lhs_shape = lhs.struct_info.shape.values + rhs_shape = rhs.struct_info.shape.values + + if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)): + # Reduction axis must be constant + return False + + if lhs_dtype == "int8" and rhs_dtype == "int8": + return False + elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8": + return False + + lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) + rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) + + if "bias" in context.annotated_expr: + if lhs_dtype == "int8" and rhs_dtype == "int8": + # Non-default epilogue not supported for IGEMM + return False + bias = context.annotated_expr["bias"] + bias_shape = bias.struct_info.shape.values + bias_batches = reduce(operator.mul, bias_shape[:-1], 1) + if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1: + # hipblas only supports bias vector + return False + + # hipblasLt does not seem to support batched GEMM with one of matrices having + # one batch (with batch_stride 0). So for batched GEMM, the two batch counts + # must be equal. If lhs is batched but rhs is not, we can use the regular GEMM by + # flattening all batch axes into the M axis. + return ( + isinstance(lhs_batches, tvm.tir.Var) + or isinstance(rhs_batches, tvm.tir.Var) + or (int(lhs_batches) == int(rhs_batches)) + or (lhs_batches >= 1 and rhs_batches == 1) + ) + + +register_patterns( + [ + ( + "hipblas.matmul", + *make_matmul_pattern( + with_bias=False, + ), + _check_matmul, + ), + ( + "hipblas.matmul_bias", + *make_matmul_pattern( + with_bias=True, + ), + _check_matmul, + ), + ( + "hipblas.matmul_bias_relu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + ), + _check_matmul, + ), + ( + "hipblas.matmul_bias_gelu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + ), + _check_matmul, + ), + ( + "hipblas.matmul_transposed", + *make_matmul_pattern( + with_bias=False, + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "hipblas.matmul_transposed_bias", + *make_matmul_pattern( + with_bias=True, + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "hipblas.matmul_transposed_bias_relu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + transposed_rhs=True, + ), + _check_matmul, + ), + ( + "hipblas.matmul_transposed_bias_gelu", + *make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + transposed_rhs=True, + ), + _check_matmul, + ), + ] +) + + +def partition_for_hipblas(mod): + """ + Partition the input module into hipblas-supported subgraphs. + + Parameters + ---------- + mod: tvm.IRModule + The IRModule to be partitioned. + + Returns + ------- + mod: tvm.IRModule + The resulting IRModule, containing partitioned subgraphs to be + offloaded to the hipblas backend. + """ + + patterns = get_patterns_with_prefix("hipblas") + return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 64eaccb410c8..8227530f7ab7 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -949,6 +949,9 @@ def _multi_gpu_exists(): parent_features="rocm", ) +# Mark a test as requiring the hipBLAS library. +requires_hipblas = Feature("hipblas", "hipBLAS", cmake_flag="USE_HIPBLAS", parent_features="rocm") + # Mark a test as requiring the metal runtime requires_metal = Feature( "metal", diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc new file mode 100644 index 000000000000..7de5c50a614d --- /dev/null +++ b/src/relax/backend/contrib/hipblas/codegen.cc @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/hipblas/codegen.cc + * \brief Implementation of the HIPBLAS JSON serializer. + */ +#include + +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONSerializer = backend::contrib::JSONSerializer; +using backend::contrib::NodeEntries; + +class HipblasJSONSerializer : public JSONSerializer { + public: + HipblasJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + NodeEntries VisitExpr_(const CallNode* call_node) final { + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(fn.defined()) << "Expects the callee to be a function."; + + auto composite_opt = fn->GetAttr(attr::kComposite); + ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + + std::string composite_name = composite_opt.value(); + + NodeEntries inputs_tmp; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end()); + } + + ICHECK(inputs_tmp.size() <= 3); + NodeEntries inputs(inputs_tmp.size()); + + auto arg_idx = backend::ExtractArgIdx(composite_name, fn); + inputs[0] = inputs_tmp[arg_idx["lhs"]->value]; + inputs[1] = inputs_tmp[arg_idx["rhs"]->value]; + if (inputs_tmp.size() == 3) { + inputs[2] = inputs_tmp[arg_idx["bias"]->value]; + } + + auto node = std::make_shared(composite_name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + + const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul"); + SetCallNodeAttribute(node, root_call); + return AddNode(node, GetRef(call_node)); + } + + private: + /*! \brief The bindings to look up composite functions. */ + Map bindings_; +}; + +Array HipblasCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; + + for (const auto& func : functions) { + HipblasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + auto graph_json = serializer.GetJSON(); + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.HipblasJSONRuntimeCreate"); + ICHECK(pf != nullptr) << "Cannot find HIPBLAS runtime module create function."; + auto func_name = GetExtSymbol(func); + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.hipblas").set_body_typed(HipblasCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc new file mode 100644 index 000000000000..c135a2855d89 --- /dev/null +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Use external hipblas library call. + */ +#include +#include +#include + +#include "../../3rdparty/compiler-rt/builtin_fp16.h" +#include "../cblas/gemm_common.h" +#include "hipblas_utils.h" + +namespace tvm { +namespace contrib { + +using namespace runtime; +inline hipblasOperation_t HIPBLASBooleanToTranspose(bool item) { + return item ? HIPBLAS_OP_T : HIPBLAS_OP_N; +} + +struct HipblasHgemmOp { + typedef hipblasHalf TDatatype; + hipblasHandle_t handle; + explicit HipblasHgemmOp(hipblasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, hipblasHalf alpha, hipblasHalf* A, int lda, + hipblasHalf* B, int ldb, hipblasHalf beta, hipblasHalf* C, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasHgemm(handle, HIPBLASBooleanToTranspose(ta), + HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, lda, B, ldb, + &beta, C, ldc)); + } +}; + +struct HipblasSgemmOp { + typedef float TDatatype; + hipblasHandle_t handle; + explicit HipblasSgemmOp(hipblasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasSgemm(handle, HIPBLASBooleanToTranspose(ta), + HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, lda, B, ldb, + &beta, C, ldc)); + } +}; + +struct HipblasDgemmOp { + typedef double TDatatype; + hipblasHandle_t handle; + explicit HipblasDgemmOp(hipblasHandle_t hdl) : handle(hdl) {} + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasDgemm(handle, HIPBLASBooleanToTranspose(ta), + HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, lda, B, ldb, + &beta, C, ldc)); + } +}; + +struct HipblasHgemmBatchOp { + typedef hipblasHalf TDatatype; + hipblasHandle_t handle; + explicit HipblasHgemmBatchOp(hipblasHandle_t hdl) : handle(hdl) {} + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, hipblasHalf alpha, + hipblasHalf* A, int a_stride, int lda, hipblasHalf* B, int b_stride, int ldb, + hipblasHalf beta, hipblasHalf* C, int c_stride, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasHgemmStridedBatched( + handle, HIPBLASBooleanToTranspose(ta), HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, + lda, a_stride, B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); + } +}; + +struct HipblasSgemmBatchOp { + typedef float TDatatype; + hipblasHandle_t handle; + explicit HipblasSgemmBatchOp(hipblasHandle_t hdl) : handle(hdl) {} + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasSgemmStridedBatched( + handle, HIPBLASBooleanToTranspose(ta), HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, + lda, a_stride, B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); + } +}; + +struct HipblasDgemmBatchOp { + typedef double TDatatype; + hipblasHandle_t handle; + explicit HipblasDgemmBatchOp(hipblasHandle_t hdl) : handle(hdl) {} + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CHECK_HIPBLAS_ERROR(hipblasDgemmStridedBatched( + handle, HIPBLASBooleanToTranspose(ta), HIPBLASBooleanToTranspose(tb), M, N, K, &alpha, A, + lda, a_stride, B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); + } +}; + +// Check supported mix-precision computation type and return computeType +bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_support = true) { + if (int_support && TypeMatch(out_dtype, kDLInt, 32)) { + return TypeMatch(in_dtype, kDLInt, 8); + } else if (TypeMatch(out_dtype, kDLFloat, 32)) { + return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16); + } else { + return false; + } +} + +void CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream, + hipblasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, + const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, + bool transb, void* workspace_ptr, size_t workspace_size, + hipblasLtEpilogue_t epilogue) { + ICHECK(TypeEqual(A->dtype, B->dtype)); + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? !transb : transb; + + auto compute_type = HIPBLAS_COMPUTE_32F; + auto scale_type = HIP_R_32F; + hipDataType ab_type = HIP_R_32F; + hipDataType c_type = HIP_R_32F; + float one_fp32 = 1.0; + float zero_fp32 = 0.0; + int32_t one_i32 = 1; + int32_t zero_i32 = 0; + void* alpha = &one_fp32; + void* beta = &zero_fp32; + + if (TypeMatch(A->dtype, kDLFloat, 16)) { + ab_type = HIP_R_16F; + } else if (TypeMatch(A->dtype, kDLInt, 8)) { + ab_type = HIP_R_8I; + } + + if (TypeMatch(C->dtype, kDLFloat, 16)) { + c_type = HIP_R_16F; + } else if (TypeMatch(C->dtype, kDLInt, 32)) { + c_type = HIP_R_32I; + compute_type = HIPBLAS_COMPUTE_32I; + scale_type = HIP_R_32I; + alpha = &one_i32; + beta = &zero_i32; + } + + hipblasLtMatmulDesc_t op_desc; + hipblasOperation_t op_transa = HIPBLASBooleanToTranspose(transa); + hipblasOperation_t op_transb = HIPBLASBooleanToTranspose(transb); + + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(op_desc, HIPBLASLT_MATMUL_DESC_TRANSA, + &op_transb, sizeof(op_transb))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(op_desc, HIPBLASLT_MATMUL_DESC_TRANSB, + &op_transa, sizeof(op_transa))); + + if (bias != nullptr) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(op_desc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias->data, sizeof(float*))); + } + + if (epilogue != HIPBLASLT_EPILOGUE_DEFAULT) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(op_desc, HIPBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, sizeof(epilogue))); + } + + int batch_offset_A = A->ndim - 2; + int batch_offset_B = B->ndim - 2; + + int M = ColumnCount(B, transb, batch_offset_B); + int N = RowCount(A, transa, batch_offset_A); + int K = ColumnCount(A, transa, batch_offset_A); + bool use_batched_gemm = A->ndim > 2 || B->ndim > 2; + + // If A is batched but B is not, flatten all non-reduction axes of A to use the regular GEMM. + // This trick is only applicable if batch axes and the other spatial axis (M or N) are + // adjacent in both the input and the output matrix. In particular, if A is of shape (M, K) + // and B matrix is of shape (Batch, N, K) with transb = true, the output shape + // is (Batch, M, N). Since the Batch and the N axes are not adjacent in the output, we cannot + // use the regular GEMM if only B is batched. + if (A->ndim > 2 && B->ndim == 2 && transa == false) { + N = 1; + for (int i = 0; i < A->ndim - 1; ++i) { + N *= A->shape[i]; + } + use_batched_gemm = false; + } + + int lda = transb ? K : M; + int ldb = transa ? N : K; + int ldc = M; + + hipblasLtMatrixLayout_t A_desc, B_desc, C_desc; + CHECK_HIPBLAS_ERROR( + hipblasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? K : M, lda)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? N : K, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc)); + + if (use_batched_gemm) { + auto get_batch_count = [](int64_t* shape, int batch_offset) { + int64_t count = 1; + for (int i = 0; i < batch_offset; ++i) { + count *= shape[i]; + } + return count; + }; + auto set_batch = [](hipblasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_desc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count))); + CHECK_HIPBLAS_ERROR( + hipblasLtMatrixLayoutSetAttribute(mat_desc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, sizeof(batch_stride))); + }; + + int batch_count_A = get_batch_count(A->shape, batch_offset_A); + int batch_count_B = get_batch_count(B->shape, batch_offset_B); + int batch_count_C = get_batch_count(C->shape, C->ndim - 2); + int64_t batch_stride_A = M * K; + int64_t batch_stride_B = K * N; + int64_t batch_stride_C = M * N; + + // hipBLASLt does not seem to support batched GEMM with one of matrices having + // one batch (with batch_stride 0). + ICHECK_EQ(batch_count_A, batch_count_B); + + set_batch(A_desc, batch_count_A, batch_stride_A); + set_batch(B_desc, batch_count_B, batch_stride_B); + set_batch(C_desc, batch_count_C, batch_stride_C); + } + + auto A_data = static_cast(A->data) + A->byte_offset; + auto B_data = static_cast(B->data) + B->byte_offset; + auto C_data = static_cast(C->data) + C->byte_offset; + + hipblasLtMatmulPreferenceSetAttribute(matmul_pref_desc, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(size_t)); + + hipblasLtMatmulHeuristicResult_t heuristic_result = {}; + int returned_result = 0; + CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic(hdl, op_desc, A_desc, B_desc, C_desc, C_desc, + matmul_pref_desc, 1, &heuristic_result, + &returned_result)); + if (returned_result == 0) { + CHECK_HIPBLAS_ERROR(HIPBLAS_STATUS_NOT_SUPPORTED); + } + + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta, + C_data, C_desc, C_data, C_desc, &heuristic_result.algo, + workspace_ptr, workspace_size, stream)); + + hipblasLtMatmulDescDestroy(op_desc); + hipblasLtMatrixLayoutDestroy(A_desc); + hipblasLtMatrixLayoutDestroy(B_desc); + hipblasLtMatrixLayoutDestroy(C_desc); +} + +inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, hipblasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + ICHECK_EQ(A->ndim, 2); + ICHECK_EQ(B->ndim, 2); + ICHECK_EQ(C->ndim, 2); + + ICHECK_EQ(ElementStride(A), 1); + ICHECK_EQ(ElementStride(B), 1); + ICHECK_EQ(ElementStride(C), 1); + + ICHECK(TypeEqual(A->dtype, B->dtype)); + + // C can never be transposed. + ICHECK(!IsInPlaceTransposed(C)); + + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? !transb : transb; + + ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; + ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + double alpha = args.size() > 5 ? args[5] : 1.0; + double beta = args.size() > 6 ? args[6] : 0.0; + + hipblasDatatype_t hip_in_type = GetHipBlasDataType(A->dtype); + hipblasDatatype_t hip_out_type = GetHipBlasDataType(C->dtype); + hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT; + void *alpha_ptr = nullptr, *beta_ptr = nullptr; + auto alpha_int = static_cast(alpha); + auto beta_int = static_cast(beta); + auto alpha_float = static_cast(alpha); + auto beta_float = static_cast(beta); + if (C->dtype.code == kDLInt) { + alpha_ptr = &alpha_int; + beta_ptr = &beta_int; + } else if (C->dtype.code == kDLFloat) { + alpha_ptr = &alpha_float; + beta_ptr = &beta_float; + } + + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + CHECK_HIPBLAS_ERROR( + hipblasGemmEx(hdl, HIPBLASBooleanToTranspose(transb), HIPBLASBooleanToTranspose(transa), + ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), alpha_ptr, + B_data, hip_in_type, ColumnStride(B), A_data, hip_in_type, ColumnStride(A), + beta_ptr, C_data, hip_out_type, ColumnStride(C), hip_out_type, algo)); +} + +inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, hipblasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + ICHECK_EQ(A->ndim, 3); + ICHECK_EQ(B->ndim, 3); + ICHECK_EQ(C->ndim, 3); + + int batch_size = BatchCount3D(C); + ICHECK_EQ(ElementStride3D(A), 1); + ICHECK_EQ(ElementStride3D(B), 1); + ICHECK_EQ(ElementStride3D(C), 1); + + ICHECK(TypeEqual(A->dtype, B->dtype)); + + // C can never be transposed. + ICHECK(!IsInPlaceTransposed3D(C)); + + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed3D(A) ? !transa : transa; + transb = IsInPlaceTransposed3D(B) ? !transb : transb; + + ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type"; + ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + double alpha = args.size() > 5 ? args[5] : 1.0; + double beta = args.size() > 6 ? args[6] : 0.0; + + int A_stride = A->shape[1] * A->shape[2]; + int B_stride = B->shape[1] * B->shape[2]; + int C_stride = C->shape[1] * C->shape[2]; + + // Broadcast A or B by changing its stride. + int batch_size_a = BatchCount3D(A); + int batch_size_b = BatchCount3D(B); + if (batch_size_a != batch_size_b) { + if (batch_size_a == 1) { + A_stride = 0; + } else if (batch_size_b == 1) { + B_stride = 0; + } + } else { + ICHECK_EQ(batch_size_a, batch_size); + ICHECK_EQ(batch_size_b, batch_size); + } + + hipblasDatatype_t hip_in_type = GetHipBlasDataType(A->dtype); + hipblasDatatype_t hip_out_type = GetHipBlasDataType(C->dtype); + hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT; + void *alpha_ptr = nullptr, *beta_ptr = nullptr; + auto alpha_int = static_cast(alpha); + auto beta_int = static_cast(beta); + auto alpha_float = static_cast(alpha); + auto beta_float = static_cast(beta); + if (C->dtype.code == kDLInt) { + alpha_ptr = &alpha_int; + beta_ptr = &beta_int; + } else if (C->dtype.code == kDLFloat) { + alpha_ptr = &alpha_float; + beta_ptr = &beta_float; + } + + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + CHECK_HIPBLAS_ERROR(hipblasGemmStridedBatchedEx( + hdl, HIPBLASBooleanToTranspose(transb), HIPBLASBooleanToTranspose(transa), + ColumnCount3D(B, transb), RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, + hip_in_type, ColumnStride3D(B), B_stride, A_data, hip_in_type, ColumnStride3D(A), A_stride, + beta_ptr, C_data, hip_out_type, ColumnStride3D(C), C_stride, batch_size, hip_out_type, algo)); +} + +// matrix multiplication for row major +TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; + + HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); + + if (TypeEqual(A->dtype, C->dtype)) { + ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 16)) { + CallGemm(args, ret, HipblasHgemmOp(entry_ptr->handle)); + } else if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallGemm(args, ret, HipblasSgemmOp(entry_ptr->handle)); + } else { + CallGemm(args, ret, HipblasDgemmOp(entry_ptr->handle)); + } + } else { + CallGemmEx(args, ret, entry_ptr->handle); + } +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.batch_matmul") + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; + + HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); + + if (TypeEqual(A->dtype, C->dtype)) { + ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 16)) { + CallBatchGemm(args, ret, HipblasHgemmBatchOp(entry_ptr->handle)); + } else if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, HipblasSgemmBatchOp(entry_ptr->handle)); + } else { + CallBatchGemm(args, ret, HipblasDgemmBatchOp(entry_ptr->handle)); + } + } else { + CallBatchGemmEx(args, ret, entry_ptr->handle); + } + }); + +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc new file mode 100644 index 000000000000..a6e7949e4559 --- /dev/null +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/hipblas/hipblas_json_runtime.cc + * \brief A simple JSON runtime for HIPBLAS. + */ + +#include +#include + +#include +#include +#include + +#include "../json/json_node.h" +#include "../json/json_runtime.h" +#include "hipblas_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { +using namespace tvm::runtime; +using namespace tvm::runtime::json; +class HipblasJSONRuntime : public JSONRuntimeBase { + public: + HipblasJSONRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + + void Init(const Array& consts) override {} + + PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since HipblasJSONRuntime + // can be used by multiple GPUs running on different threads, we avoid using that function + // and directly call hipBLAS on the inputs from TVMArgs. + if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK(this->initialized_) << "The module has not been initialized"; + this->Run(args); + }); + } else { + return JSONRuntimeBase::GetFunction(name, sptr_to_self); + } + } + + const char* type_key() const override { return "hipblas_json"; } // May be overridden + + void Run(TVMArgs args) { + auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(); + + auto func = tvm::runtime::Registry::Get("runtime.get_rocm_stream"); + ICHECK(func != nullptr); + hipStream_t stream = static_cast((*func)().operator void*()); + + std::vector dl_tensors(NumEntries()); + + for (size_t i = 0; i < static_cast(args.size()); i++) { + auto eid = i < input_var_eid_.size() ? input_var_eid_[i] + : EntryID(outputs_[i - input_var_eid_.size()]); + ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle) + << "Expect NDArray or DLTensor as inputs"; + + const DLTensor* arg; + if (args[i].IsObjectRef()) { + NDArray arr = args[i]; + arg = arr.operator->(); + } else { + arg = args[i].operator DLTensor*(); + } + + dl_tensors[eid] = arg; + } + + auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { + ICHECK_LT(idx, node.GetInputs().size()); + auto eid = EntryID(node.GetInputs()[idx]); + ICHECK(eid < dl_tensors.size()); + return dl_tensors[eid]; + }; + + auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) { + const DLTensor* bias = nullptr; + if (has_bias) { + bias = get_input(node, 2); + } + return std::make_tuple(get_input(node, 0), get_input(node, 1), bias); + }; + + for (size_t i = 0; i < nodes_.size(); ++i) { + const auto& node = nodes_[i]; + if (node.GetOpType() == "kernel") { + auto op_name = node.GetOpName(); + uint32_t output_eid = EntryID(outputs_[0]); + auto out_ptr = dl_tensors[output_eid]; + bool transa = false; + bool transb = false; + hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; + + if (op_name.find("transposed") != std::string::npos) { + transb = true; + } + + if (op_name.find("relu") != std::string::npos) { + epilogue = HIPBLASLT_EPILOGUE_RELU_BIAS; + } else if (op_name.find("gelu") != std::string::npos) { + epilogue = HIPBLASLT_EPILOGUE_GELU_BIAS; + } else if (op_name.find("bias") != std::string::npos) { + epilogue = HIPBLASLT_EPILOGUE_BIAS; + } + + auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != HIPBLASLT_EPILOGUE_DEFAULT); + + tvm::contrib::CallHipblasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr, + b_ptr, bias_ptr, out_ptr, transa, transb, + entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue); + } + } + } + + void Run() override { LOG(FATAL) << "Unreachable"; } +}; + +runtime::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { + auto n = make_object(symbol_name, graph_json, const_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.HipblasJSONRuntimeCreate").set_body_typed(HipblasJSONRuntimeCreate); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hipblas_json") + .set_body_typed(JSONRuntimeBase::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc new file mode 100644 index 000000000000..02d91646518c --- /dev/null +++ b/src/runtime/contrib/hipblas/hipblas_utils.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Use external hipblas utils function + */ +#include "hipblas_utils.h" + +#include +#include + +#include "../../rocm/rocm_common.h" + +namespace tvm { +namespace contrib { + +HipBlasThreadEntry::HipBlasThreadEntry() { CHECK_HIPBLAS_ERROR(hipblasCreate(&handle)); } + +HipBlasThreadEntry::~HipBlasThreadEntry() { + if (handle) { + hipblasDestroy(handle); + handle = nullptr; + } +} + +typedef dmlc::ThreadLocalStore HipBlasThreadStore; + +HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal() { + auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream; + HipBlasThreadEntry* retval = HipBlasThreadStore::Get(); + CHECK_HIPBLAS_ERROR(hipblasSetStream(retval->handle, static_cast(stream))); + return retval; +} + +HipBlasLtThreadEntry::HipBlasLtThreadEntry() { + CHECK_HIPBLAS_ERROR(hipblasLtCreate(&handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&matmul_pref_desc)); + ROCM_CALL(hipMalloc(&workspace_ptr, workspace_size)); +} + +HipBlasLtThreadEntry::~HipBlasLtThreadEntry() { + if (handle) { + hipblasLtDestroy(handle); + handle = nullptr; + } + if (matmul_pref_desc) { + hipblasLtMatmulPreferenceDestroy(matmul_pref_desc); + matmul_pref_desc = nullptr; + } + if (workspace_ptr != nullptr) { + hipFree(workspace_ptr); + workspace_ptr = nullptr; + } +} + +typedef dmlc::ThreadLocalStore HipBlasLtThreadStore; + +HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal() { return HipBlasLtThreadStore::Get(); } + +} // namespace contrib + +} // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas_utils.h b/src/runtime/contrib/hipblas/hipblas_utils.h new file mode 100644 index 000000000000..66d7afafbd64 --- /dev/null +++ b/src/runtime/contrib/hipblas/hipblas_utils.h @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Use external hipblas utils function + */ +#ifndef TVM_RUNTIME_CONTRIB_HIPBLAS_HIPBLAS_UTILS_H_ +#define TVM_RUNTIME_CONTRIB_HIPBLAS_HIPBLAS_UTILS_H_ + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace contrib { +inline const char* GetHipblasErrorString(int error) { + switch (error) { + case HIPBLAS_STATUS_NOT_INITIALIZED: + return "HIPBLAS_STATUS_NOT_INITIALIZED"; + case HIPBLAS_STATUS_ALLOC_FAILED: + return "HIPBLAS_STATUS_ALLOC_FAILED"; + case HIPBLAS_STATUS_INVALID_VALUE: + return "HIPBLAS_STATUS_INVALID_VALUE"; + case HIPBLAS_STATUS_ARCH_MISMATCH: + return "HIPBLAS_STATUS_ARCH_MISMATCH"; + case HIPBLAS_STATUS_MAPPING_ERROR: + return "HIPBLAS_STATUS_MAPPING_ERROR"; + case HIPBLAS_STATUS_EXECUTION_FAILED: + return "HIPBLAS_STATUS_EXECUTION_FAILED"; + case HIPBLAS_STATUS_INTERNAL_ERROR: + return "HIPBLAS_STATUS_INTERNAL_ERROR"; + case HIPBLAS_STATUS_NOT_SUPPORTED: + return "HIPBLAS_STATUS_NOT_SUPPORTED"; + } + return "Unrecognized error"; +} + +#ifndef CHECK_HIPBLAS_ERROR +#define CHECK_HIPBLAS_ERROR(fn) \ + do { \ + int error = static_cast(fn); \ + ICHECK_EQ(error, HIPBLAS_STATUS_SUCCESS) << "HIPBLAS: " << GetHipblasErrorString(error); \ + } while (0) // ; intentionally left off. +#endif // CHECK_HIPBLAS_ERROR + +struct HipBlasThreadEntry { + HipBlasThreadEntry(); + ~HipBlasThreadEntry(); + hipblasHandle_t handle{nullptr}; + static HipBlasThreadEntry* ThreadLocal(); +}; // HipBlasThreadEntry + +struct HipBlasLtThreadEntry { + HipBlasLtThreadEntry(); + ~HipBlasLtThreadEntry(); + + hipblasLtHandle_t handle{nullptr}; + hipblasLtMatmulPreference_t matmul_pref_desc{nullptr}; + void* workspace_ptr{nullptr}; + // 32MB workspace as suggested by NVIDIA + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace. + static constexpr const size_t workspace_size = 33554432; + + static HipBlasLtThreadEntry* ThreadLocal(); +}; // HipBlasLtThreadEntry + +inline hipDataType GetHipDataType(DLDataType type) { + if (type.code == kDLInt) { + switch (type.bits) { + case 8: + return HIP_R_8I; + case 32: + return HIP_R_32I; + } + } else if (type.code == kDLUInt) { + switch (type.bits) { + case 8: + return HIP_R_8U; + case 32: + return HIP_R_32U; + } + } else if (type.code == kDLFloat) { + switch (type.bits) { + case 16: + return HIP_R_16F; + case 32: + return HIP_R_32F; + case 64: + return HIP_R_64F; + } + } + LOG(FATAL) << "Unsupported hip type"; +} + +inline hipblasDatatype_t GetHipBlasDataType(DLDataType type) { + if (type.code == kDLInt) { + switch (type.bits) { + case 8: + return HIPBLAS_R_8I; + case 32: + return HIPBLAS_R_32I; + } + } else if (type.code == kDLUInt) { + switch (type.bits) { + case 8: + return HIPBLAS_R_8U; + case 32: + return HIPBLAS_R_32U; + } + } else if (type.code == kDLFloat) { + switch (type.bits) { + case 16: + return HIPBLAS_R_16F; + case 32: + return HIPBLAS_R_32F; + case 64: + return HIPBLAS_R_64F; + } + } + LOG(FATAL) << "Unsupported hip type"; +} + +/*! \brief Execute matrix multiply followed by the specified epilogue, using hipBLASLt. */ +void CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream, + hipblasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, + const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, + bool transb, void* workspace_ptr, size_t workspace_size, + hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT); + +} // namespace contrib + +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_HIPBLAS_HIPBLAS_UTILS_H_ diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 561e495a357d..984a2f3323ad 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -360,6 +360,7 @@ TVM_DLL Map GetLibInfo() { {"TVM_DEBUG_WITH_ABI_CHANGE", TVM_INFO_TVM_DEBUG_WITH_ABI_CHANGE}, {"TVM_LOG_BEFORE_THROW", TVM_INFO_TVM_LOG_BEFORE_THROW}, {"USE_ROCBLAS", TVM_INFO_USE_ROCBLAS}, + {"USE_HIPBLAS", TVM_INFO_USE_HIPBLAS}, {"USE_ROCM", TVM_INFO_USE_ROCM}, {"USE_RCCL", TVM_INFO_USE_RCCL}, {"USE_RPC", TVM_INFO_USE_RPC}, diff --git a/tests/python/contrib/test_hipblas.py b/tests/python/contrib/test_hipblas.py new file mode 100644 index 000000000000..63a7553704bf --- /dev/null +++ b/tests/python/contrib/test_hipblas.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np + +import tvm +import tvm.testing +from tvm import te +from tvm.contrib import hipblas + + +def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5): + n = 1024 + l = 128 + m = 236 + A = te.placeholder((n, l), name="A", dtype=in_dtype) + B = te.placeholder((l, m), name="B", dtype=in_dtype) + C = hipblas.matmul(A, B, dtype=out_dtype) + s = te.create_schedule(C.op) + + def verify(target="rocm"): + if not tvm.get_global_func("tvm.contrib.hipblas.matmul", True): + print("skip because extern function is not available") + return + dev = tvm.rocm(0) + f = tvm.build(s, [A, B, C], target) + a = tvm.nd.array(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), dev) + f(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), np.dot(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)), rtol=rtol + ) + + verify() + + +def roundoff(v, d): + return int(np.floor((v + d - 1) / d) * d) + + +def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5): + A = te.placeholder(Ashape, name="A", dtype=in_dtype) + B = te.placeholder(Bshape, name="B", dtype=in_dtype) + C = hipblas.batch_matmul(A, B, dtype=out_dtype) + s = te.create_schedule(C.op) + + dev = tvm.rocm(0) + f = tvm.build(s, [A, B, C], "rocm") + + if "int" in in_dtype: + a = tvm.nd.array(np.random.uniform(1, 10, size=Ashape).astype(in_dtype), dev) + b = tvm.nd.array(np.random.uniform(1, 10, size=Bshape).astype(in_dtype), dev) + else: + a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev) + + c = tvm.nd.array(np.zeros(Cshape, dtype=C.dtype), dev) + f(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), + np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype), + rtol=rtol, + ) + + +@tvm.testing.requires_rocm +def test_matmul_add(): + verify_matmul_add("float", "float", rtol=1e-3) + verify_matmul_add("float16", "float") + verify_matmul_add("float16", "float16", rtol=1e-2) + verify_matmul_add("int8", "int32") + + +@tvm.testing.requires_rocm +def test_batch_matmul(): + if not tvm.get_global_func("tvm.contrib.hipblas.batch_matmul", True): + print("skip because extern function is not available") + return + + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float", "float") + verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float", "float") + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float") + verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float") + verify_batch_matmul( + (16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2 + ) + verify_batch_matmul( + (16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2 + ) + + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "int8", "int32") + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_codegen_hipblas.py b/tests/python/relax/test_codegen_hipblas.py new file mode 100644 index 000000000000..f43b83802b81 --- /dev/null +++ b/tests/python/relax/test_codegen_hipblas.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest + +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relax +from tvm.relax.backend.contrib.hipblas import partition_for_hipblas +from tvm.relax.testing import get_relax_matmul_module +from tvm.script import relax as R + +try: + import ml_dtypes +except ImportError: + ml_dtypes = None + + +@pytest.fixture(autouse=True) +def reset_seed(): + np.random.seed(0) + + +pytestmark = tvm.testing.requires_hipblas.marks() + + +def build_and_run(mod, inputs_np, target, legalize=False): + dev = tvm.device(target, 0) + with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}): + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + return f(*inputs).numpy() + + +def get_result_with_relax_cublas_offload(mod, np_inputs): + mod = partition_for_hipblas(mod) + mod = relax.transform.RunCodegen()(mod) + + return build_and_run(mod, np_inputs, "rocm") + + +def _to_concrete_shape(symbolic_shape, var_table): + result = [] + for dim in symbolic_shape: + if not isinstance(dim, tvm.tir.expr.Var): + result.append(dim) + continue + + if dim not in var_table: + var_table[dim] = np.random.randint(10, 50) + result.append(var_table[dim]) + + return tuple(result) + + +_vars = { + "a": tvm.tir.expr.Var("a", "int64"), + "b": tvm.tir.expr.Var("b", "int64"), +} + + +_epilogue_table = { + "none": (False, None), + "bias": (True, None), + "relu": (True, R.nn.relu), + "gelu": (True, R.nn.gelu), +} + + +@pytest.mark.parametrize( + "x_shape, y_shape, transpose_y, epilogue", + [ + # Regular + ((8, 8), (8, 8), False, "none"), + ((_vars["a"], 6), (6, 16), False, "bias"), + # Transposed + ((4, 16), (16, 128), True, "relu"), + ((35, 8), (8, 8), True, "gelu"), + # # 3D x 3D + ((6, 32, 8), (6, 8, 10), False, "bias"), + ((6, 32, 8), (6, 8, 10), True, "none"), + ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"), + # ND x ND + ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"), + # ND x 2D + ((5, 3, 32, 8), (8, 10), False, "none"), + ], +) +@pytest.mark.parametrize( + "in_dtype, out_dtype", + [ + ("float16", "float16"), + ("float32", "float32"), + ], +) +def test_matmul_offload( + x_shape, + y_shape, + transpose_y, + epilogue, + in_dtype, + out_dtype, +): + with_bias, activation = _epilogue_table[epilogue] + var_table = {} + concrete_x_shape = _to_concrete_shape(x_shape, var_table) + concrete_y_shape = _to_concrete_shape(y_shape, var_table) + x = np.random.randn(*concrete_x_shape).astype(in_dtype) + y = np.random.randn(*concrete_y_shape).astype(in_dtype) + + if transpose_y: + y = np.swapaxes(y, -2, -1) + y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2]) + + if with_bias: + bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype) + args = (x, y, bias) + else: + bias = None + args = (x, y) + + mod = get_relax_matmul_module( + x_shape, + y_shape, + in_dtype, + out_dtype, + bias_shape=bias.shape if with_bias else None, + transposed_y=transpose_y, + activation=activation, + ) + + out = get_result_with_relax_cublas_offload(mod, args) + ref = build_and_run(mod, args, "llvm", legalize=True) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + +def test_hipblas_partition_matmul_without_bias(): + # hipBLAS does not handle 2D bias (residual input) + mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32)) + mod = partition_for_hipblas(mod) + + # R.add is still in the main function + assert len(mod["main"].body.blocks[0].bindings) == 2 + + +if __name__ == "__main__": + tvm.testing.main() From 481c2dc85209fa3d104c020b0d8d8e4ce7ed20c1 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 23 Aug 2024 07:16:44 +0900 Subject: [PATCH 076/202] [Relax][PyTorch] Add support for torch.tile (#17291) * add test * add support for torch.tile --- .../tvm/relax/frontend/torch/fx_translator.py | 9 ++++ tests/python/relax/test_frontend_from_fx.py | 42 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 093f3ae4cf7a..35131d324076 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -612,6 +612,14 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var: dim = None return self.block_builder.emit(relax.op.squeeze(x, dim)) + def _tile(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.tile(args[0], args[1:])) + def _cumsum(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] @@ -1450,6 +1458,7 @@ def create_convert_map(self): "permute": self._permute, "reshape": self._reshape, "split": self._split, + "tile": self._tile, "cumsum": self._cumsum, "chunk": self._chunk, "transpose": self._transpose, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 1a2cc5da6242..6be3e7b23e9d 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3126,6 +3126,48 @@ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype= verify_model(Reshape(), input_info, {}, expected1) +def test_tile(): + input_info = [([1, 3], "float32")] + + class Tile1(Module): + def forward(self, x): + return x.tile((2,)) + + class Tile2(Module): + def forward(self, x): + return x.tile(4, 2) + + class Tile3(Module): + def forward(self, x): + return torch.tile(x, (4, 2)) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((1, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2]) + gv: R.Tensor((1, 6), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) + gv: R.Tensor((4, 6), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tile1(), input_info, {}, expected1) + verify_model(Tile2(), input_info, {}, expected2) + verify_model(Tile3(), input_info, {}, expected2) + + def test_transpose(): input_info = [([1, 2, 3, 4], "float32")] From 9e865b4b8fdf4cc624e94f8db9e5674c4519db05 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 23 Aug 2024 06:16:56 +0800 Subject: [PATCH 077/202] [Docs] Introduce Relax API and move legacy part to standalone page (#17286) * [Docs] Introduce Relax API and move legacy part to standalone page As the TVM project evolves, the Unity strategy has been the recommended way to use Apache TVM applications. Hence, we are pushing documentation for the Relax API to the forefront and moving the legacy part to a standalone page, which may be removed in the future. * update for ci * update for ci --- docs/arch/index.rst | 9 -- docs/conf.py | 41 +++++++ docs/dev/how_to/relay_add_op.rst | 6 +- docs/index.rst | 20 ++-- docs/reference/api/python/dlight.rst | 22 ++++ docs/reference/api/python/index.rst | 113 +++++++++++++----- docs/reference/api/python/instrument.rst | 22 ++++ docs/reference/api/python/ir.rst | 16 --- docs/reference/api/python/relax/analysis.rst | 22 ++++ .../api/python/relax/block_builder.rst | 21 ++++ docs/reference/api/python/relax/frontend.rst | 48 ++++++++ docs/reference/api/python/relax/op.rst | 72 +++++++++++ docs/reference/api/python/relax/relax.rst | 23 ++++ docs/reference/api/python/relax/transform.rst | 24 ++++ docs/reference/api/python/relay/transform.rst | 1 + docs/reference/api/python/runtime/disco.rst | 22 ++++ .../api/python/{ => runtime}/ndarray.rst | 6 - .../api/python/runtime/profiling.rst | 21 ++++ .../{vta/index.rst => runtime/relax_vm.rst} | 30 +---- .../api/python/{ => runtime}/runtime.rst | 3 - docs/reference/api/python/tir/analysis.rst | 21 ++++ docs/reference/api/python/tir/schedule.rst | 22 ++++ .../reference/api/python/tir/stmt_functor.rst | 21 ++++ docs/reference/api/python/tir/tir.rst | 23 ++++ .../api/python/{tir.rst => tir/transform.rst} | 27 ----- docs/reference/api/python/transform.rst | 22 ++++ docs/{arch => reference}/security.rst | 0 python/tvm/driver/build_module.py | 4 +- python/tvm/relax/op/create.py | 2 +- python/tvm/relax/transform/transform.py | 27 ++--- python/tvm/runtime/profiling/__init__.py | 3 +- python/tvm/target/__init__.py | 2 +- python/tvm/te/operation.py | 2 +- python/tvm/tir/buffer.py | 2 +- 34 files changed, 569 insertions(+), 151 deletions(-) create mode 100644 docs/reference/api/python/dlight.rst create mode 100644 docs/reference/api/python/instrument.rst create mode 100644 docs/reference/api/python/relax/analysis.rst create mode 100644 docs/reference/api/python/relax/block_builder.rst create mode 100644 docs/reference/api/python/relax/frontend.rst create mode 100644 docs/reference/api/python/relax/op.rst create mode 100644 docs/reference/api/python/relax/relax.rst create mode 100644 docs/reference/api/python/relax/transform.rst create mode 100644 docs/reference/api/python/runtime/disco.rst rename docs/reference/api/python/{ => runtime}/ndarray.rst (88%) create mode 100644 docs/reference/api/python/runtime/profiling.rst rename docs/reference/api/python/{vta/index.rst => runtime/relax_vm.rst} (61%) rename docs/reference/api/python/{ => runtime}/runtime.rst (95%) create mode 100644 docs/reference/api/python/tir/analysis.rst create mode 100644 docs/reference/api/python/tir/schedule.rst create mode 100644 docs/reference/api/python/tir/stmt_functor.rst create mode 100644 docs/reference/api/python/tir/tir.rst rename docs/reference/api/python/{tir.rst => tir/transform.rst} (68%) create mode 100644 docs/reference/api/python/transform.rst rename docs/{arch => reference}/security.rst (100%) diff --git a/docs/arch/index.rst b/docs/arch/index.rst index b84afeea2818..17884a774253 100644 --- a/docs/arch/index.rst +++ b/docs/arch/index.rst @@ -408,15 +408,6 @@ Frontends ingest models from different frameworks into the TVM stack. frontend/tensorflow - -Security ---------- -.. toctree:: - :maxdepth: 1 - - security - - microTVM -------- .. toctree:: diff --git a/docs/conf.py b/docs/conf.py index c3472c15de91..1c5c5cb5d602 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -39,6 +39,7 @@ import re import sys from textwrap import dedent, indent +from typing import List from unittest.mock import patch # If extensions (or modules to document with autodoc) are in another directory, @@ -718,10 +719,50 @@ def update_alias_docstring(name, obj, lines): lines.append(".. rubric:: Alias of %s:`%s.%s`" % (obj_type, amod, target_name)) +tvm_class_name_rewrite_map = { + "tvm.tir": ["Var", "Call"], + "tvm.relax": ["Var", "Call"], + "tvm.relax.frontend.nn": ["Module"], +} + + +def distinguish_class_name(name: str, lines: List[str]): + """Distinguish the docstring of type annotations. + + In the whole TVM, there are many classes with the same name but in different modules, + e.g. ``tir.Var``, ``relax.Var``. This function is used to distinguish them in the docstring, + by adding the module name as prefix. + + To be specific, this function will check the current object name, and if it in the specific + module with specific name, it will add the module name as prefix to the class name to prevent + the confusion. Further, we only add the prefix to those standalone class name, but skip + the pattern of `xx.Var`, `Var.xx` and `xx.Var.xx`. + + Parameters + ---------- + name : str + The full name of the object in the doc. + + lines : list + The docstring lines, need to be modified inplace. + """ + remap = {} + for module_name in tvm_class_name_rewrite_map: + if name.startswith(module_name): + short_name = module_name[4:] if module_name.startswith("tvm.") else module_name + for class_name in tvm_class_name_rewrite_map[module_name]: + remap.update({class_name: f"{short_name}.{class_name}"}) + + for k, v in remap.items(): + for i in range(len(lines)): + lines[i] = re.sub(rf"(?`, :ref:`TVM's operator inventory (topi) ` and looking at the example cumulative sum and product implementations found in `python/tvm/topi/scan.py`_ and the gpu versions in -`python/tvm/topi/cuda/scan.py`_. In the case of our cumulative sum and product -operations we write things directly in :ref:`TIR ` which is the -representation where tensor expressions and topi will lower into. +`python/tvm/topi/cuda/scan.py`_. .. _python/tvm/topi/scan.py: https://github.com/apache/tvm/blob/main/python/tvm/topi/scan.py .. _python/tvm/topi/cuda/scan.py: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/scan.py -5. Hooking up Compute and Strategy with Relay +1. Hooking up Compute and Strategy with Relay --------------------------------------------- After you have implemented your compute function we now need to glue it to our diff --git a/docs/index.rst b/docs/index.rst index 7f13101f741e..2b7896c652d0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,23 +52,29 @@ driving its costs down. .. toctree:: :maxdepth: 1 - :caption: Architecture Guide + :caption: API Reference - arch/index + reference/api/python/index + reference/api/links .. toctree:: :maxdepth: 1 - :caption: Topic Guides + :caption: Legacy + reference/langref/index + arch/index topic/microtvm/index topic/vta/index .. toctree:: :maxdepth: 1 - :caption: Reference Guide + :caption: About - reference/langref/index - reference/api/python/index - reference/api/links reference/publications + reference/security + +.. toctree:: + :maxdepth: 1 + :caption: Index + genindex diff --git a/docs/reference/api/python/dlight.rst b/docs/reference/api/python/dlight.rst new file mode 100644 index 000000000000..37859ed790f4 --- /dev/null +++ b/docs/reference/api/python/dlight.rst @@ -0,0 +1,22 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.dlight +---------- +.. automodule:: tvm.dlight + :members: + :imported-members: diff --git a/docs/reference/api/python/index.rst b/docs/reference/api/python/index.rst index 5dc1ed806dfd..e64ea304cbee 100644 --- a/docs/reference/api/python/index.rst +++ b/docs/reference/api/python/index.rst @@ -18,34 +18,89 @@ Python API ========== +.. toctree:: + :maxdepth: 1 + :caption: tvm + + error + ir + instrument + transform + target + driver + +.. toctree:: + :maxdepth: 1 + :caption: tvm.runtime + + runtime/runtime + runtime/ndarray + runtime/relax_vm + runtime/disco + runtime/profiling + +.. toctree:: + :maxdepth: 1 + :caption: tvm.relax + + relax/relax + relax/analysis + relax/block_builder + relax/frontend + relax/op + relax/transform + +.. toctree:: + :maxdepth: 1 + :caption: tvm.tir + + tir/tir + tir/analysis + tir/schedule + tir/stmt_functor + tir/transform + +.. toctree:: + :maxdepth: 1 + :caption: tvm.te + + te + topi + +.. toctree:: + :maxdepth: 1 + :caption: tvm.meta_schedule + + meta_schedule + +.. toctree:: + :maxdepth: 1 + :caption: tvm.dlight + + dlight + +.. toctree:: + :maxdepth: 1 + :caption: Misc + + rpc + contrib .. toctree:: - :maxdepth: 2 - - runtime - ndarray - error - ir - target - tir - te - driver - relay/index - relay/frontend - relay/nn - relay/vision - relay/image - relay/transform - relay/analysis - relay/backend - relay/dataflow_pattern - relay/testing - autotvm - auto_scheduler - meta_schedule - rpc - micro - contrib - graph_executor - topi - vta/index + :maxdepth: 1 + :caption: Legacy + + relay/index + relay/frontend + relay/nn + relay/vision + relay/image + relay/transform + relay/analysis + relay/backend + relay/dataflow_pattern + relay/testing + autotvm + auto_scheduler + micro + graph_executor diff --git a/docs/reference/api/python/instrument.rst b/docs/reference/api/python/instrument.rst new file mode 100644 index 000000000000..270a19690b9e --- /dev/null +++ b/docs/reference/api/python/instrument.rst @@ -0,0 +1,22 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.instrument +-------------- +.. automodule:: tvm.instrument + :members: + :imported-members: diff --git a/docs/reference/api/python/ir.rst b/docs/reference/api/python/ir.rst index e7fb3c114689..1f0dc0c5e23c 100644 --- a/docs/reference/api/python/ir.rst +++ b/docs/reference/api/python/ir.rst @@ -21,19 +21,3 @@ tvm.ir :members: :imported-members: :autosummary: - - -tvm.instrument --------------- -.. automodule:: tvm.instrument - :members: - :imported-members: - :autosummary: - - -tvm.transform -------------- -.. automodule:: tvm.transform - :members: - :imported-members: - :autosummary: diff --git a/docs/reference/api/python/relax/analysis.rst b/docs/reference/api/python/relax/analysis.rst new file mode 100644 index 000000000000..b6598b54574e --- /dev/null +++ b/docs/reference/api/python/relax/analysis.rst @@ -0,0 +1,22 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relax.analysis +------------------ +.. automodule:: tvm.relax.analysis + :members: + :imported-members: diff --git a/docs/reference/api/python/relax/block_builder.rst b/docs/reference/api/python/relax/block_builder.rst new file mode 100644 index 000000000000..a1c2a7c4354b --- /dev/null +++ b/docs/reference/api/python/relax/block_builder.rst @@ -0,0 +1,21 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relax.block_builder +----------------------- +.. automodule:: tvm.relax.block_builder + :members: diff --git a/docs/reference/api/python/relax/frontend.rst b/docs/reference/api/python/relax/frontend.rst new file mode 100644 index 000000000000..c037f323ed1a --- /dev/null +++ b/docs/reference/api/python/relax/frontend.rst @@ -0,0 +1,48 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relax.frontend +------------------ +.. automodule:: tvm.relax.frontend + :members: + :imported-members: + +tvm.relax.frontend.nn +********************* +.. automodule:: tvm.relax.frontend.nn + :members: + :imported-members: + :exclude-members: BlockBuilder + :noindex: + +tvm.relax.frontend.onnx +*********************** +.. automodule:: tvm.relax.frontend.onnx + :members: + :imported-members: + +tvm.relax.frontend.stablehlo +**************************** +.. automodule:: tvm.relax.frontend.stablehlo + :members: + :imported-members: + +tvm.relax.frontend.torch +************************ +.. automodule:: tvm.relax.frontend.torch + :members: + :imported-members: diff --git a/docs/reference/api/python/relax/op.rst b/docs/reference/api/python/relax/op.rst new file mode 100644 index 000000000000..21f638442a84 --- /dev/null +++ b/docs/reference/api/python/relax/op.rst @@ -0,0 +1,72 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relax.op +------------ + +tvm.relax.op +************ +.. automodule:: tvm.relax.op + :members: + :imported-members: + +tvm.relax.op.nn +*************** +.. automodule:: tvm.relax.op.nn + :members: + :imported-members: + +tvm.relax.op.builtin +******************** +.. automodule:: tvm.relax.op.builtin + :members: + :imported-members: + +tvm.relax.op.ccl +**************** +.. automodule:: tvm.relax.op.ccl + :members: + :imported-members: + +tvm.relax.op.distributed +************************ +.. automodule:: tvm.relax.op.distributed + :members: + :imported-members: + +tvm.relax.op.grad +***************** +.. automodule:: tvm.relax.op.grad + :members: + :imported-members: + +tvm.relax.op.image +****************** +.. automodule:: tvm.relax.op.image + :members: + :imported-members: + +tvm.relax.op.memory +******************* +.. automodule:: tvm.relax.op.memory + :members: + :imported-members: + +tvm.relax.op.op_attrs +********************* +.. automodule:: tvm.relax.op.op_attrs + :members: diff --git a/docs/reference/api/python/relax/relax.rst b/docs/reference/api/python/relax/relax.rst new file mode 100644 index 000000000000..4df1f1279b59 --- /dev/null +++ b/docs/reference/api/python/relax/relax.rst @@ -0,0 +1,23 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relax +--------- +.. automodule:: tvm.relax + :members: + :imported-members: + :exclude-members: BlockBuilder, Span, GlobalVar, SourceName, TupleType, Type, FuncType diff --git a/docs/reference/api/python/relax/transform.rst b/docs/reference/api/python/relax/transform.rst new file mode 100644 index 000000000000..dcb41e80fd67 --- /dev/null +++ b/docs/reference/api/python/relax/transform.rst @@ -0,0 +1,24 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _api-relax-transformation: + +tvm.relax.transform +------------------- +.. automodule:: tvm.relax.transform + :members: + :imported-members: diff --git a/docs/reference/api/python/relay/transform.rst b/docs/reference/api/python/relay/transform.rst index c66904d8bcba..4a8747606eb2 100644 --- a/docs/reference/api/python/relay/transform.rst +++ b/docs/reference/api/python/relay/transform.rst @@ -22,3 +22,4 @@ tvm.relay.transform :members: :imported-members: :autosummary: + :exclude-members: FunctionPass diff --git a/docs/reference/api/python/runtime/disco.rst b/docs/reference/api/python/runtime/disco.rst new file mode 100644 index 000000000000..6a9b60394732 --- /dev/null +++ b/docs/reference/api/python/runtime/disco.rst @@ -0,0 +1,22 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.runtime.disco +----------------- +.. automodule:: tvm.runtime.disco + :members: + :imported-members: diff --git a/docs/reference/api/python/ndarray.rst b/docs/reference/api/python/runtime/ndarray.rst similarity index 88% rename from docs/reference/api/python/ndarray.rst rename to docs/reference/api/python/runtime/ndarray.rst index aa828905ca21..8c794f04b193 100644 --- a/docs/reference/api/python/ndarray.rst +++ b/docs/reference/api/python/runtime/ndarray.rst @@ -18,10 +18,4 @@ tvm.runtime.ndarray ------------------- .. automodule:: tvm.runtime.ndarray - -.. autoclass:: tvm.nd.NDArray :members: - :inherited-members: - -.. autofunction:: tvm.nd.array -.. autofunction:: tvm.nd.empty diff --git a/docs/reference/api/python/runtime/profiling.rst b/docs/reference/api/python/runtime/profiling.rst new file mode 100644 index 000000000000..d26f00af90c6 --- /dev/null +++ b/docs/reference/api/python/runtime/profiling.rst @@ -0,0 +1,21 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.runtime.profiling +--------------------- +.. automodule:: tvm.runtime.profiling + :members: diff --git a/docs/reference/api/python/vta/index.rst b/docs/reference/api/python/runtime/relax_vm.rst similarity index 61% rename from docs/reference/api/python/vta/index.rst rename to docs/reference/api/python/runtime/relax_vm.rst index 479b8394f0cb..75afcb7939ab 100644 --- a/docs/reference/api/python/vta/index.rst +++ b/docs/reference/api/python/runtime/relax_vm.rst @@ -15,31 +15,7 @@ specific language governing permissions and limitations under the License. -vta -=== - -This document contains the python API to VTA compiler toolchain. - -.. automodule:: vta - -Hardware Information +tvm.runtime.relax_vm -------------------- - -.. autofunction:: vta.Environment -.. autofunction:: vta.get_env - -RPC Utilities -------------- - -.. autofunction:: vta.reconfig_runtime -.. autofunction:: vta.program_fpga - - -Compiler API ------------- -We program VTA using TVM, so the compiler API in vta package -is only a thin wrapper to provide VTA specific extensions. - -.. autofunction:: vta.build_config -.. autofunction:: vta.build -.. autofunction:: vta.lower +.. automodule:: tvm.runtime.relax_vm + :members: diff --git a/docs/reference/api/python/runtime.rst b/docs/reference/api/python/runtime/runtime.rst similarity index 95% rename from docs/reference/api/python/runtime.rst rename to docs/reference/api/python/runtime/runtime.rst index c51a2d452065..4dd9d9653369 100644 --- a/docs/reference/api/python/runtime.rst +++ b/docs/reference/api/python/runtime/runtime.rst @@ -17,9 +17,6 @@ tvm.runtime ----------- - .. automodule:: tvm.runtime :members: - :imported-members: :exclude-members: NDArray - :autosummary: diff --git a/docs/reference/api/python/tir/analysis.rst b/docs/reference/api/python/tir/analysis.rst new file mode 100644 index 000000000000..aa777358bcf2 --- /dev/null +++ b/docs/reference/api/python/tir/analysis.rst @@ -0,0 +1,21 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.tir.analysis +---------------- +.. automodule:: tvm.tir.analysis.analysis + :members: diff --git a/docs/reference/api/python/tir/schedule.rst b/docs/reference/api/python/tir/schedule.rst new file mode 100644 index 000000000000..17e4a4593a47 --- /dev/null +++ b/docs/reference/api/python/tir/schedule.rst @@ -0,0 +1,22 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.tir.schedule +----------------- +.. automodule:: tvm.tir.schedule + :members: + :imported-members: diff --git a/docs/reference/api/python/tir/stmt_functor.rst b/docs/reference/api/python/tir/stmt_functor.rst new file mode 100644 index 000000000000..3b6c9bb64a89 --- /dev/null +++ b/docs/reference/api/python/tir/stmt_functor.rst @@ -0,0 +1,21 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.tir.stmt_functor +-------------------- +.. automodule:: tvm.tir.stmt_functor + :members: diff --git a/docs/reference/api/python/tir/tir.rst b/docs/reference/api/python/tir/tir.rst new file mode 100644 index 000000000000..3f82fe8261ac --- /dev/null +++ b/docs/reference/api/python/tir/tir.rst @@ -0,0 +1,23 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.tir +------- +.. automodule:: tvm.tir + :members: + :imported-members: + :exclude-members: PrimExpr, const, StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/docs/reference/api/python/tir.rst b/docs/reference/api/python/tir/transform.rst similarity index 68% rename from docs/reference/api/python/tir.rst rename to docs/reference/api/python/tir/transform.rst index 2152be69ea6f..8ce641b6d3f6 100644 --- a/docs/reference/api/python/tir.rst +++ b/docs/reference/api/python/tir/transform.rst @@ -15,36 +15,9 @@ specific language governing permissions and limitations under the License. -.. _api-python-tir: - -tvm.tir -------- -.. automodule:: tvm.tir - :members: - :imported-members: - :exclude-members: PrimExpr, const - :autosummary: - tvm.tir.transform ----------------- .. automodule:: tvm.tir.transform :members: :imported-members: - :autosummary: - - -tvm.tir.analysis ----------------- -.. automodule:: tvm.tir.analysis - :members: - :imported-members: - :noindex: Buffer, Stmt - :autosummary: - - -tvm.tir.stmt_functor --------------------- -.. automodule:: tvm.tir.stmt_functor - :members: - :autosummary: diff --git a/docs/reference/api/python/transform.rst b/docs/reference/api/python/transform.rst new file mode 100644 index 000000000000..d200dfdd1139 --- /dev/null +++ b/docs/reference/api/python/transform.rst @@ -0,0 +1,22 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.transform +------------- +.. automodule:: tvm.transform + :members: + :imported-members: diff --git a/docs/arch/security.rst b/docs/reference/security.rst similarity index 100% rename from docs/arch/security.rst rename to docs/reference/security.rst diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index c332062b37b9..08af27e32f04 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -105,7 +105,7 @@ def lower( inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule] The TE schedule or TensorIR PrimFunc/IRModule to be built - args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] + args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, tir.Var]]] The argument lists to the function for TE schedule. It should be None if we want to lower TensorIR. @@ -156,7 +156,7 @@ def build( inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] The input to be built - args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] + args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, tir.Var]]] The argument lists to the function. target : Optional[Union[str, Target]] diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py index 8fd3b2cde1e7..092d79a74dc4 100644 --- a/python/tvm/relax/op/create.py +++ b/python/tvm/relax/op/create.py @@ -241,7 +241,7 @@ def tril(x: Expr, k: Union[int, PrimExpr, Expr] = 0) -> Expr: return _ffi_api.tril(x, k) # type: ignore -def triu(x: Expr, k: [int, PrimExpr, Expr] = 0) -> Expr: +def triu(x: Expr, k: Union[int, PrimExpr, Expr] = 0) -> Expr: """Return the upper triangular part of a matrix or a batch of matrices. Parameters diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 2546284625e9..95649f331f33 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -391,8 +391,8 @@ def ConvertToDataflow(min_size: int = 2) -> tvm.ir.transform.Pass: Note: ConvertToDataflow may need to be called first. - Params - ------ + Parameters + ---------- min_size: int The minimum number of consecutive dataflow bindings the pass needs to extract a new block. @@ -647,13 +647,8 @@ def BindParams( func_name: str The function name to be bound - params : Dict[ - Union[str,relax.Var], - Union[tvm.runtime.NDArray, np.ndarray], - ] - - The map from parameter or parameter name to constant - tensors. + params: Dict[Union[str,relax.Var], Union[tvm.runtime.NDArray, np.ndarray]] + The map from parameter or parameter name to constant tensors. Returns ------- @@ -994,16 +989,16 @@ def LiftTransformParams(shared_transform: Union[bool, List[str]] = False) -> tvm Indicates how the parameter transformation function will be produced - `False` (default): A separate parameter transformation function will be - produced for each function with the `"num_input"` attribute. + produced for each function with the `"num_input"` attribute. - `True`: A single parameter transformation function will be produced, - containing the preprocessing steps common across all functions with - the `"num_input"` attribute. + containing the preprocessing steps common across all functions with + the `"num_input"` attribute. - List[str]: A single parameter transformation function will be produced, - containing the preprocessing steps common across each function whose - name is in the list. Passing a list of all functions with the `"num_input"` - attribute or an empty list is equivalent to passing `True`. + containing the preprocessing steps common across each function whose + name is in the list. Passing a list of all functions with the `"num_input"` + attribute or an empty list is equivalent to passing `True`. Returns ------- @@ -1219,7 +1214,7 @@ def MetaScheduleTuneIRMod( maximum number of trials per task op_names: Optional[List[str]] A list of operator names to specify which op to tune. When it is None, all operators - are tuned. + are tuned. Returns ------- diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 347d8b9f94f1..23ce5476f5b0 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -230,6 +230,7 @@ def profile_function(mod, dev, collectors, func_name=None, warmup_iters=10): ------- .. code-block: python + f = tvm.build(my_func, target="llvm", name="my_func") prof = tvm.runtime.profiling.profile_function( f, @@ -247,7 +248,7 @@ def profile_function(mod, dev, collectors, func_name=None, warmup_iters=10): Device to run the function on. collectors: List[MetricCollector] - :py:class:`MetricCollector`s which will collect performance information. + :py:class:`MetricCollector` which will collect performance information. func_name: Optional[str] Name of the function in `mod` to profile. Defaults to the `entry_name` of `mod`. warmup_iters: int diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 78a7e0160db7..14bd4753d400 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -51,7 +51,7 @@ Build TVM system library module. System lib is a global module that contains self registered functions in program startup. User can get the module using - :any:`tvm.runtime.system_lib`. + `tvm.runtime.system_lib`. It is useful in environments where dynamic loading api like dlopen is banned. The system lib will be available as long as the result code is linked by the program. diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 64a282dcf755..63a3ecd57b1c 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -459,7 +459,7 @@ def var(name="tindex", dtype="int32", span=None): Returns ------- - var : Var + var : tir.Var The result symbolic variable. """ return tvm.tir.Var(name, dtype, span) diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 501d13b17e3d..1109cc3d66d6 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -262,7 +262,7 @@ def decl_buffer( name : str, optional The name of the buffer. - data : Var, optional + data : tir.Var, optional The data pointer in the buffer. strides: array of Expr From e1da4651df0afcea740f53f590aa42450f3795ed Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 23 Aug 2024 20:05:55 +0800 Subject: [PATCH 078/202] [Doc] IRModule (#17298) --- docs/get_started/tutorials/ir_module.py | 281 ++++++++++++++++++++++++ docs/index.rst | 1 + 2 files changed, 282 insertions(+) create mode 100644 docs/get_started/tutorials/ir_module.py diff --git a/docs/get_started/tutorials/ir_module.py b/docs/get_started/tutorials/ir_module.py new file mode 100644 index 000000000000..f813333bafc3 --- /dev/null +++ b/docs/get_started/tutorials/ir_module.py @@ -0,0 +1,281 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _ir_module: + +IRModule +======== +This tutorial presents the core abstraction of Apache TVM Unity, the IRModule. +The IRModule encompasses the **entirety** of the ML models, incorporating the +computational graph, tensor programs, and potential calls to external libraries. + +.. contents:: Table of Contents + :local: + :depth: 1 +""" + +import numpy as np +import tvm +from tvm import relax + +###################################################################### +# Create IRModule +# --------------- +# IRModules can be initialized in various ways. We demonstrate a few of them +# below. + +import torch +from torch import fx, nn +from tvm.relax.frontend.torch import from_fx + +###################################################################### +# Import from existing models +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# The most common way to initialize an IRModule is to import from an existing +# model. Apache TVM Unity accommodates imports from a range of frameworks, +# such as PyTorch and ONNX. This tutorial solely demonstrates the import process +# from PyTorch. + + +# Create a dummy model +class TorchModel(nn.Module): + def __init__(self): + super(TorchModel, self).__init__() + self.fc1 = nn.Linear(784, 256) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(256, 10) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + return x + + +# Give the input shape and data type +input_info = [((1, 784), "float32")] + +# Convert the model to IRModule +with torch.no_grad(): + torch_fx_model = fx.symbolic_trace(TorchModel()) + mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True) + +mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch) +# Print the IRModule +mod_from_torch.show() + +###################################################################### +# Write with Relax NN Module +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Apache TVM Unity also provides a set of PyTorch-liked APIs, to help users +# write the IRModule directly. + +from tvm.relax.frontend import nn + + +class RelaxModel(nn.Module): + def __init__(self): + super(RelaxModel, self).__init__() + self.fc1 = nn.Linear(784, 256) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(256, 10) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + return x + + +mod_from_relax, params_from_relax = RelaxModel().export_tvm( + {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}} +) +mod_from_relax.show() + +###################################################################### +# Create via TVMScript +# ~~~~~~~~~~~~~~~~~~~~ +# TVMScript is a Python-based DSL for IRModules. We are able to +# directly output the IRModule in the TVMScript syntax, or alternatively, +# parse the TVMScript to obtain an IRModule. + +from tvm.script import ir as I +from tvm.script import relax as R + + +@I.ir_module +class TVMScriptModule: + @R.function + def main( + x: R.Tensor((1, 784), dtype="float32"), + fc1_weight: R.Tensor((256, 784), dtype="float32"), + fc1_bias: R.Tensor((256,), dtype="float32"), + fc2_weight: R.Tensor((10, 256), dtype="float32"), + fc2_bias: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor((1, 10), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + permute_dims = R.permute_dims(fc1_weight, axes=None) + matmul = R.matmul(x, permute_dims, out_dtype="void") + add = R.add(matmul, fc1_bias) + relu = R.nn.relu(add) + permute_dims1 = R.permute_dims(fc2_weight, axes=None) + matmul1 = R.matmul(relu, permute_dims1, out_dtype="void") + add1 = R.add(matmul1, fc2_bias) + gv = add1 + R.output(gv) + return gv + + +mod_from_script = TVMScriptModule +mod_from_script.show() + +###################################################################### +# Attributes of an IRModule +# ------------------------- +# An IRModule is a collection of functions, indexed by GlobalVars. + +mod = mod_from_torch +print(mod.get_global_vars()) + +###################################################################### +# We can access the functions in the IRModule by indexing with the GlobalVars +# or their names + +# index by global var name +print(mod["main"]) +# index by global var, and checking they are the same function +(gv,) = mod.get_global_vars() +assert mod[gv] == mod["main"] + +###################################################################### +# Transformations on IRModules +# ---------------------------- +# Transformations are the import component of Apache TVM Unity. One transformation +# takes in an IRModule and outputs another IRModule. We can apply a sequence of +# transformations to an IRModule to obtain a new IRModule. That is the common way to +# optimize a model. +# +# In this getting started tutorial, we only demonstrate how to apply transformations +# to an IRModule. For details of each transformation, please refer to the +# :ref:`Transformation API Reference ` + +###################################################################### +# We first apply **LegalizeOps** transformation to the IRModule. This transformation +# will convert the Relax module into a mixed stage, with both Relax and TensorIR function +# within the same module. Meanwhile, the Relax operators will be converted into ``call_tir``. + +mod = mod_from_torch +mod = relax.transform.LegalizeOps()(mod) +mod.show() + +###################################################################### +# After the transformation, there are much more functions inside the module. Let's print +# the global vars again. + +print(mod.get_global_vars()) + +###################################################################### +# Next, Apache TVM Unity provides a set of default transformation pipelines for users, +# to simplify the transformation process. We can then apply the default pipeline to the module. +# The default **zero** pipeline contains very fundamental transformations, including: +# +# - **LegalizeOps**: This transform converts the Relax operators into `call_tir` functions +# with the corresponding TensorIR Functions. After this transform, the IRModule will +# contain both Relax functions and TensorIR functions. +# - **AnnotateTIROpPattern**: This transform annotates the pattern of the TensorIR functions, +# preparing them for subsequent operator fusion. +# - **FoldConstant**: This pass performs constant folding, optimizing operations +# involving constants. +# - **FuseOps and FuseTIR**: These two passes work together to fuse operators based on the +# patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform +# both Relax functions and TensorIR functions. +# +# .. note:: +# +# Here, we have applied **LegalizeOps** twice in the flow. The second time is useless but +# harmless. +# +# Every passes can be duplicated in the flow, since we ensure the passes can handle all legal +# IRModule inputs. This design can help users to construct their own pipeline. + +mod = relax.get_pipeline("zero")(mod) +mod.show() + +###################################################################### +# Deploy the IRModule Universally +# ------------------------------- +# After the optimization, we can compile the model into a TVM runtime module. +# Notably, Apache TVM Unity provides the ability of universal deployment, which means +# we can deploy the same IRModule on different backends, including CPU, GPU, and other emerging +# backends. +# +# Deploy on CPU +# ~~~~~~~~~~~~~ +# We can deploy the IRModule on CPU by specifying the target as ``llvm``. + +exec = relax.build(mod, target="llvm") +dev = tvm.cpu() +vm = relax.VirtualMachine(exec, dev) + +raw_data = np.random.rand(1, 784).astype("float32") +data = tvm.nd.array(raw_data, dev) +cpu_out = vm["main"](data, *params_from_torch["main"]).numpy() +print(cpu_out) + +###################################################################### +# Deploy on GPU +# ~~~~~~~~~~~~~ +# Besides, CPU backend, we can also deploy the IRModule on GPU. GPU requires +# programs containing extra information, such as the thread bindings and shared memory +# allocations. We need a further transformation to generate the GPU programs. +# +# We use ``DLight`` to generate the GPU programs. In this tutorial, we won't go into +# the details of ``DLight``. +# + +from tvm import dlight as dl + +with tvm.target.Target("cuda"): + gpu_mod = dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + dl.gpu.Fallback(), + )(mod) + +###################################################################### +# Now we can compile the IRModule on GPU, the similar way as we did on CPU. + +exec = relax.build(gpu_mod, target="cuda") +dev = tvm.device("cuda", 0) +vm = relax.VirtualMachine(exec, dev) +# Need to allocate data and params on GPU device +data = tvm.nd.array(raw_data, dev) +gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]] +gpu_out = vm["main"](data, *gpu_params).numpy() +print(gpu_out) + +# Check the correctness of the results +assert np.allclose(cpu_out, gpu_out, atol=1e-3) + +###################################################################### +# Deploy on Other Backends +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# Apache TVM Unity also supports other backends, such as different kinds of GPUs +# (Metal, ROCm, Vulkan and OpenCL), different kinds of CPUs (x86, ARM), and other +# emerging backends (e.g., WebAssembly). The deployment process is similar to the +# GPU backend. diff --git a/docs/index.rst b/docs/index.rst index 2b7896c652d0..2fc8ce7980da 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -34,6 +34,7 @@ driving its costs down. install/index get_started/tutorials/quick_start + get_started/tutorials/ir_module contribute/index .. toctree:: From 15180082626d01ccad0648a088d11a29e0678790 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Fri, 23 Aug 2024 08:49:33 -0700 Subject: [PATCH 079/202] [Web] Add TVMArgBool to ArgTypeCode (#17251) --- web/src/ctypes.ts | 5 +++-- web/src/runtime.ts | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index cb2a0e1097b4..c4941f07d57a 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -171,7 +171,7 @@ export type FTVMBackendPackedCFunc = ( /** * int TVMObjectFree(TVMObjectHandle obj); */ - export type FTVMObjectFree = (obj: Pointer) => number; +export type FTVMObjectFree = (obj: Pointer) => number; /** * int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); @@ -252,5 +252,6 @@ export const enum ArgTypeCode { TVMStr = 11, TVMBytes = 12, TVMNDArrayHandle = 13, - TVMObjectRValueRefArg = 14 + TVMObjectRValueRefArg = 14, + TVMArgBool = 15, } diff --git a/web/src/runtime.ts b/web/src/runtime.ts index e446c4dc4dfb..600a9b857f03 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -2474,6 +2474,7 @@ export class Instance implements Disposable { switch (tcode) { case ArgTypeCode.Int: case ArgTypeCode.UInt: + case ArgTypeCode.TVMArgBool: return this.memory.loadI64(rvaluePtr); case ArgTypeCode.Float: return this.memory.loadF64(rvaluePtr); From ca22bad77d66adeba7ce9e61dcfd6f39c40f0dc0 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 24 Aug 2024 01:51:42 +0800 Subject: [PATCH 080/202] [Doc] Overview (#17296) Overview page for Apache TVM. --- docs/get_started/overview.rst | 66 +++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + 2 files changed, 67 insertions(+) create mode 100644 docs/get_started/overview.rst diff --git a/docs/get_started/overview.rst b/docs/get_started/overview.rst new file mode 100644 index 000000000000..5931837d16c1 --- /dev/null +++ b/docs/get_started/overview.rst @@ -0,0 +1,66 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Overview +======== + +Apache TVM is a machine learning compilation framework, following the principle of **Python-first development** +and **universal deployment**. It takes in pre-trained machine learning models, +compiles and generates deployable modules that can be embedded and run everywhere. Apache TVM also enables customizing optimization processes to introduce new optimizations, libraries, codegen +and more. + +Key Principle +------------- + +- **Python-first**: the optimization process is fully customizable in Python. + It is easy to customize the optimization pipeline without recompiling the TVM stack. +- **Composable**: the optimization process is composable. It is easy to compose + new optimization passes, libraries and codegen to the existing pipeline. + +Key Goals +--------- + +- **Optimize** performance of ML workloads, composing libraries and codegen. +- **Deploy** ML workloads to a diverse set of new environments, including new runtime and new hardware. +- **Continuously improve and customize** ML deployment pipeline in Python by quickly customizing library dispatching, + bringing in customized operators and code generation. + +Key Flow +-------- + +Here is a typical flow of using TVM to deploy a machine learning model. For a runnable example, +please refer to :ref:`quick_start` + +1. **Import/construct an ML model** + + TVM supports importing models from various frameworks, such as PyTorch, TensorFlow for generic ML models. Meanwhile, we can create models directly using Relax frontend for scenarios of large language models. + +2. **Perform composable optimization** transformations via ``pipelines`` + + The pipeline encapsulates a collection of transformations to achieve two goals: + + - **Graph Optimizations**: such as operator fusion, and layout rewrites. + - **Tensor Program Optimization**: Map the operators to low-level implementations (both library or codegen) + + .. note:: + + The two are goals but not the stages of the pipeline. The two optimizations are performed + **at the same level**, or separately in two stages. + +3. **Build and universal deploy** + + Apache TVM aims to provide a universal deployment solution to bring machine learning everywhere with every language with minimum runtime support. TVM runtime can work in non-Python environments, so it works on mobile, edge devices or even bare metal devices. Additionally, TVM runtime comes with native data structures, and can also have zero copy exchange with the existing ecosystem (PyTorch, TensorFlow, TensorRT, etc.) using DLPack support. diff --git a/docs/index.rst b/docs/index.rst index 2fc8ce7980da..07022cdef7ae 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ driving its costs down. :maxdepth: 1 :caption: Getting Started + get_started/overview install/index get_started/tutorials/quick_start get_started/tutorials/ir_module From 541f9c280c567b63630229bc03855d43fc6811af Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sat, 24 Aug 2024 08:44:04 -0700 Subject: [PATCH 081/202] [Rocm] Fix non-standard rocm path (#17295) * [Rocm] Fix non-standard rocm path --- python/tvm/contrib/rocm.py | 16 ++++++++++++---- src/runtime/rocm/rocm_device_api.cc | 3 ++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py index 119a2c588c99..f3427463b3e0 100644 --- a/python/tvm/contrib/rocm.py +++ b/python/tvm/contrib/rocm.py @@ -136,8 +136,10 @@ def callback_rocm_bitcode_path(rocdl_dir=None): # seems link order matters. if rocdl_dir is None: - if exists("/opt/rocm/amdgcn/bitcode/"): - rocdl_dir = "/opt/rocm/amdgcn/bitcode/" # starting with rocm 3.9 + rocm_path = find_rocm_path() + amdgcn_path = f"{rocm_path}/amdgcn/bitcode/" + if exists(amdgcn_path): + rocdl_dir = amdgcn_path # starting with rocm 3.9 else: rocdl_dir = "/opt/rocm/lib/" # until rocm 3.8 @@ -226,7 +228,7 @@ def have_matrixcore(compute_version=None): @tvm._ffi.register_func("tvm_callback_rocm_get_arch") -def get_rocm_arch(rocm_path="/opt/rocm"): +def get_rocm_arch(rocm_path=None): """Utility function to get the AMD GPU architecture Parameters @@ -239,9 +241,15 @@ def get_rocm_arch(rocm_path="/opt/rocm"): gpu_arch : str The AMD GPU architecture """ + if rocm_path is None: + try: + rocm_path = find_rocm_path() + except RuntimeError: + rocm_path = None + gpu_arch = "gfx900" # check if rocm is installed - if not os.path.exists(rocm_path): + if rocm_path is None or not os.path.exists(rocm_path): print("ROCm not detected, using default gfx900") return gpu_arch try: diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index c37e9fada5b2..ebfd312595a3 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -139,7 +139,8 @@ class ROCMDeviceAPI final : public DeviceAPI { case kAvailableGlobalMemory: // Not currently implemented. - break; + *rv = nullptr; + return; } *rv = value; } From 47e964a5973575c1e270c62b0fd785135e1b5bca Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Mon, 26 Aug 2024 04:27:47 -0700 Subject: [PATCH 082/202] [Codegen][WebGPU] LetNode common subexpr override (#17302) This PR overrides the WebGPU codegen function of `tir::LetNode` to adapt to the recent LetNode common subexpression changes. Co-authored-by: Ruihang Lai --- src/target/source/codegen_webgpu.cc | 21 +++++++++++++++++++++ src/target/source/codegen_webgpu.h | 3 ++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index b76b05470d5d..83079a9f0756 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -433,6 +433,27 @@ void CodeGenWebGPU::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOL << PrintExpr(op->condition) << ")"; } +void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) + // use ssa form. + if (print_ssa_form_) { + std::string value = PrintExpr(op->value); + ICHECK(!var_idmap_.count(op->var.get())); + var_idmap_[op->var.get()] = value; + } else { + PrintIndent(); + std::string value = PrintExpr(op->value); + this->stream << "let " << AllocVarID(op->var.get()) << " : "; + PrintType(op->var.dtype(), this->stream); + this->stream << " = " << value << ";\n"; + } + os << PrintExpr(op->body); + // Pop the defined var from var_idmap when exiting its scope. + // We do this because it is hard to completely avoid a same LetNode appearing + // at different places. + bool removed = var_idmap_.erase(op->var.get()); + ICHECK(removed); +} + void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype.bits() == 32) { std::ostringstream temp; diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index a100396b25a2..09f99fb88600 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -63,7 +63,8 @@ class CodeGenWebGPU final : public CodeGenC { void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BufferLoadNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const LetNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) From 384360f628201790ee6b3e821db060a42db8d155 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Mon, 26 Aug 2024 19:29:23 +0800 Subject: [PATCH 083/202] [Relax][Bugfix] Support torch.unbind op and fix bugs for expand && split (#17292) * support unbind * add unit test * format fix * ignore logging in ut --- .../contrib/msc/core/frontend/translate.py | 2 + .../tvm/relax/frontend/torch/fx_translator.py | 33 ++++- .../contrib/test_msc/test_graph_build.py | 54 +++++++- .../python/contrib/test_msc/test_pipeline.py | 2 +- .../contrib/test_msc/test_translate_relax.py | 41 +++++- .../contrib/test_msc/test_translate_relay.py | 34 ++++- .../test_msc/test_translate_tensorrt.py | 36 ++++- .../contrib/test_msc/test_translate_torch.py | 35 ++++- tests/python/relax/test_frontend_from_fx.py | 128 +++++++++++++++++- 9 files changed, 336 insertions(+), 29 deletions(-) diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py index 2eaae1335855..63b4424524eb 100644 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ b/python/tvm/contrib/msc/core/frontend/translate.py @@ -119,6 +119,7 @@ def from_relax( )(mod) patterns = get_patterns_with_prefix("msc.") passes = [ + tvm.relax.transform.ExpandTupleArguments(), msc_transform.SetExprName(), msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), tvm.relax.transform.FuseOpsByPattern( @@ -310,6 +311,7 @@ def byoc_partition( def _partition_mod(mod, as_msc=True): patterns = get_patterns_with_prefix(target) passes = [ + tvm.relax.transform.ExpandTupleArguments(), msc_transform.SetExprName(), msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=not as_msc), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 35131d324076..6d01283d3ecd 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -526,6 +526,22 @@ def _einsum(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0])) return self.block_builder.emit(relax.op.einsum(args[1:], args[0])) + def _unbind(self, node: fx.node.Node) -> relax.Var: + if len(node.args) == 2: + assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int" + dim = node.args[1] + elif "dim" in node.kwargs: + dim = node.kwargs["dim"] + else: + dim = 0 + x = self.env[node.args[0]] + selections = self.shape_of(x)[dim].value + n_section = list(range(1, selections + 1)) + ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + return self.block_builder.emit(relax.Tuple(ret)) + ########## Manipulation ########## def _cat(self, node: fx.node.Node) -> relax.Var: @@ -535,7 +551,13 @@ def _cat(self, node: fx.node.Node) -> relax.Var: def _expand(self, node: fx.node.Node) -> relax.Var: args = self.retrieve_args(node) - return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1:])) + broadcast_shape, in_shape = [], self.shape_of(args[0]) + for idx, i in enumerate(args[1:]): + if isinstance(i, int) and i == -1: + broadcast_shape.append(in_shape[idx]) + else: + broadcast_shape.append(i) + return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) def _flatten(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] @@ -580,7 +602,13 @@ def _split(self, node: fx.node.Node) -> relax.Var: dim = node.kwargs["dim"] else: dim = 0 - n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + if isinstance(split_size, (list, tuple)): + n_section = [] + for s in split_size[:-1]: + cum_sum = 0 if not n_section else n_section[-1] + n_section.append(s + cum_sum) + else: + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size return self.block_builder.emit(relax.op.split(x, n_section, dim)) def _chunk(self, node: fx.node.Node) -> relax.Var: @@ -1501,6 +1529,7 @@ def create_convert_map(self): "cross_entropy": self._cross_entropy, "scaled_dot_product_attention": self._scaled_dot_product_attention, "einsum": self._einsum, + "unbind": self._unbind, } def update_convert_map(self, custom_convert_map: dict): diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index 315d6813ea99..069ffff53bd7 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -1345,11 +1345,15 @@ def forward(self, x_1, x_2, x_3): def test_split(): """test graph builder for split""" - class Split(Module): + class Split1(Module): def forward(self, data): return torch.split(data, 1, dim=1) - expected = { + class Split2(Module): + def forward(self, data): + return torch.split(data, [1, 2], dim=1) + + expected1 = { "inputs": [ {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], @@ -1361,8 +1365,43 @@ def forward(self, data): "nodes": {"total": 2, "input": 1, "split": 1}, } + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_1", "shape": [1, 2, 10, 10], "dtype": "float32", "layout": "ABCD"}, + ], + "nodes": {"total": 2, "input": 1, "split": 1}, + } + + input_info = [([1, 3, 10, 10], "float32")] + verify_model(Split1(), input_info, expected1) + verify_model(Split2(), input_info, expected2) + + +def test_unbind(): + """test graph builder for unbind""" + + class Unbind(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + expected = { + "inputs": [ + {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "tuple_0", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, + {"name": "tuple_1", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, + {"name": "tuple_2", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, + ], + "nodes": {"total": 9, "input": 1, "split": 1, "get_item": 3, "squeeze": 3, "tuple": 1}, + } + input_info = [([1, 3, 10, 10], "float32")] - verify_model(Split(), input_info, expected) + verify_model(Unbind(), input_info, expected) def test_cumsum(): @@ -1547,10 +1586,14 @@ def forward(self, x): def test_expand(): """test graph builder for expand""" - class Expand(Module): + class Expand1(Module): def forward(self, x): return x.expand(4, 2, 3, 4) + class Expand2(Module): + def forward(self, x): + return x.expand(4, -1, -1, 4) + expected = { "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], "outputs": [ @@ -1560,7 +1603,8 @@ def forward(self, x): } input_info = [([1, 2, 3, 4], "float32")] - verify_model(Expand(), input_info, expected) + verify_model(Expand1(), input_info, expected) + verify_model(Expand2(), input_info, expected) def test_reduce(): diff --git a/tests/python/contrib/test_msc/test_pipeline.py b/tests/python/contrib/test_msc/test_pipeline.py index c7a26bf96efb..149041959416 100644 --- a/tests/python/contrib/test_msc/test_pipeline.py +++ b/tests/python/contrib/test_msc/test_pipeline.py @@ -38,7 +38,7 @@ def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1 path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static") return { "workspace": msc_utils.msc_dir(path), - "verbose": "info", + "verbose": "critical", "model_type": model_type, "inputs": inputs, "outputs": outputs, diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index 00975be85eca..e8b7149a68a2 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -67,7 +67,12 @@ def _run_relax(relax_mod): orig_output = _run_relax(orig_mod) rt_output = _run_relax(rt_mod) - tvm.testing.assert_allclose(orig_output, rt_output) + if not isinstance(orig_output, (list, tuple)): + orig_output = [orig_output] + if not isinstance(rt_output, (list, tuple)): + rt_output = [rt_output] + for o_out, r_out in zip(orig_output, rt_output): + tvm.testing.assert_allclose(o_out, r_out) def test_conv1d(): @@ -750,12 +755,33 @@ def forward(self, x_1, x_2, x_3): def test_split(): """test relax translator for split""" - class Split(Module): + class Split1(Module): def forward(self, data): return torch.split(data, 1, dim=1) + class Split2(Module): + def forward(self, data): + return torch.split(data, [1, 2], dim=1) + input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Split(), input_info) + _verify_model(Split1(), input_info) + _verify_model(Split2(), input_info) + + +def test_unbind(): + """test relax translator for unbind""" + + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + input_info = [([3, 3, 10, 10], "float32")] + _verify_model(Unbind1(), input_info) + _verify_model(Unbind2(), input_info) def test_cumsum(): @@ -874,12 +900,17 @@ def forward(self, x): def test_expand(): """test relax translator for expand""" - class Expand(Module): + class Expand1(Module): def forward(self, x): return x.expand(4, 2, 3, 4) + class Expand2(Module): + def forward(self, x): + return x.expand(4, -1, -1, 4) + input_info = [([1, 2, 3, 4], "float32")] - _verify_model(Expand(), input_info) + _verify_model(Expand1(), input_info) + _verify_model(Expand2(), input_info) def test_reduce(): diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py index 6c47b8b39545..3790da3f3d8e 100644 --- a/tests/python/contrib/test_msc/test_translate_relay.py +++ b/tests/python/contrib/test_msc/test_translate_relay.py @@ -731,12 +731,33 @@ def forward(self, x_1, x_2, x_3): def test_split(): """test relay to relax for split""" - class Split(Module): + class Split1(Module): def forward(self, data): return torch.split(data, 1, dim=1) + class Split2(Module): + def forward(self, data): + return torch.split(data, [1, 2], dim=1) + input_info = [([1, 3, 10, 10], "float32")] - verify_model(Split(), input_info, build_target="llvm") + verify_model(Split1(), input_info, build_target="llvm") + verify_model(Split2(), input_info, build_target="llvm") + + +def test_unbind(): + """test relay to relax for unbind""" + + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + input_info = [([3, 3, 10, 10], "float32")] + verify_model(Unbind1(), input_info, build_target="llvm") + verify_model(Unbind2(), input_info, build_target="llvm") def test_cumsum(): @@ -859,12 +880,17 @@ def forward(self, x): def test_expand(): """test relay to relax for expand""" - class Expand(Module): + class Expand1(Module): def forward(self, x): return x.expand(4, 2, 3, 4) + class Expand2(Module): + def forward(self, x): + return x.expand(4, -1, -1, 4) + input_info = [([1, 2, 3, 4], "float32")] - verify_model(Expand(), input_info, build_target="llvm") + verify_model(Expand1(), input_info, build_target="llvm") + verify_model(Expand2(), input_info, build_target="llvm") def test_reduce(): diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py index 81104e6fe0f2..74c25ceacfe8 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorrt.py +++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py @@ -673,12 +673,34 @@ def forward(self, x_1, x_2, x_3): def test_split(): """test tensorrt translator for split""" - class Split(Module): + class Split1(Module): def forward(self, data): return torch.split(data, 1, dim=1) + class Split2(Module): + def forward(self, data): + return torch.split(data, [1, 2], dim=1) + input_info = [([1, 3, 10, 10], "float32")] - verify_model(Split(), input_info) + verify_model(Split1(), input_info) + verify_model(Split2(), input_info) + + +@requires_tensorrt +def test_unbind(): + """test tensorrt to relax for unbind""" + + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + input_info = [([3, 3, 10, 10], "float32")] + verify_model(Unbind1(), input_info) + verify_model(Unbind2(), input_info) @requires_tensorrt @@ -697,13 +719,19 @@ def forward(self, data): def test_expand(): """test tensorrt translator for expand""" - class Expand(Module): + class Expand1(Module): def forward(self, x): x = x + 1.0 return x.expand(4, 2, 3, 4) + class Expand2(Module): + def forward(self, x): + x = x + 1.0 + return x.expand(4, -1, -1, 4) + input_info = [([1, 2, 3, 4], "float32")] - verify_model(Expand(), input_info) + verify_model(Expand1(), input_info) + verify_model(Expand2(), input_info) @requires_tensorrt diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py index 81c6031ce17a..60dcbb293a51 100644 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ b/tests/python/contrib/test_msc/test_translate_torch.py @@ -728,13 +728,35 @@ def forward(self, x_1, x_2, x_3): def test_split(): """test torch translator for split""" - class Split(Module): + class Split1(Module): def forward(self, data): return torch.split(data, 1, dim=1) + class Split2(Module): + def forward(self, data): + return torch.split(data, [1, 2], dim=1) + input_info = [([1, 3, 10, 10], "float32")] for via_relax in [True, False]: - verify_model(Split(), input_info, via_relax) + verify_model(Split1(), input_info, via_relax) + verify_model(Split2(), input_info, via_relax) + + +def test_unbind(): + """test torch translator for unbind""" + + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + input_info = [([3, 3, 10, 10], "float32")] + for via_relax in [True, False]: + verify_model(Unbind1(), input_info, via_relax) + verify_model(Unbind2(), input_info, via_relax) def test_cumsum(): @@ -835,13 +857,18 @@ def forward(self, x): def test_expand(): """test torch translator for expand""" - class Expand(Module): + class Expand1(Module): def forward(self, x): return x.expand(4, 2, 3, 4) + class Expand2(Module): + def forward(self, x): + return x.expand(4, -1, -1, 4) + input_info = [([1, 2, 3, 4], "float32")] for via_relax in [True, False]: - verify_model(Expand(), input_info, via_relax) + verify_model(Expand1(), input_info, via_relax) + verify_model(Expand2(), input_info, via_relax) def test_reduce(): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 6be3e7b23e9d..5398fe342073 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -2714,10 +2714,14 @@ def main( def test_split(): input_info = [([1, 3, 10, 10], "float32")] - class Split(Module): + class Split1(Module): def forward(self, input): return torch.split(input, 1, dim=1) + class Split2(Module): + def forward(self, input): + return torch.split(input, [1, 2], dim=1) + @tvm.script.ir_module class expected1: @R.function @@ -2743,7 +2747,118 @@ def main( R.output(gv) return gv - verify_model(Split(), input_info, {}, expected1) + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 2, 10, 10), dtype="float32") + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 2, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1], axis=1) + gv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 2, 10, 10), dtype="float32"), + ) = lv + R.output(gv) + return gv + + verify_model(Split1(), input_info, {}, expected1) + verify_model(Split2(), input_info, {}, expected2) + + +def test_unbind(): + input_info = [([3, 3, 10, 10], "float32")] + + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((0, 3, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = lv7 + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 0, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) + lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = lv7 + R.output(gv) + return gv + + verify_model(Unbind1(), input_info, {}, expected1) + verify_model(Unbind2(), input_info, {}, expected2) def test_cumsum(): @@ -2970,10 +3085,14 @@ def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype=" def test_expand(): input_info = [([1, 2, 3, 4], "float32")] - class Expand(Module): + class Expand1(Module): def forward(self, x): return x.expand(4, 2, 3, 4) + class Expand2(Module): + def forward(self, x): + return x.expand(4, -1, -1, 4) + @tvm.script.ir_module class expected1: @R.function @@ -2987,7 +3106,8 @@ def main( R.output(gv) return gv - verify_model(Expand(), input_info, {}, expected1) + verify_model(Expand1(), input_info, {}, expected1) + verify_model(Expand2(), input_info, {}, expected1) def test_reduce(): From d5d5ebb601a1fee5be3ff52bb8520497db1b99de Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 26 Aug 2024 07:29:40 -0400 Subject: [PATCH 084/202] [Support] Fix the Read/Write of socket stream (#17284) This PR fixes the `dmlc::Stream::Read/Write` for TCP socket. Given socket does not guarantee that all data are send received/sent in a single shot, we need to use `RecvAll/SendAll`. --- src/support/socket.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/support/socket.h b/src/support/socket.h index 032cf257c045..e3972488d4b8 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -553,9 +553,9 @@ class TCPSocket : public Socket, public dmlc::Stream { return data; } - size_t Read(void* data, size_t size) final { return Recv(data, size); } + size_t Read(void* data, size_t size) final { return RecvAll(data, size); } - size_t Write(const void* data, size_t size) final { return Send(data, size); } + size_t Write(const void* data, size_t size) final { return SendAll(data, size); } }; /*! \brief helper data structure to perform poll */ From c4acc79bdec9bd501d1732572843829d7f90c38d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 26 Aug 2024 06:31:58 -0500 Subject: [PATCH 085/202] [Relax] Avoid wrapping TupleStructInfo into a Tuple for R.call_tir (#17243) * [Relax] Avoid wrapping TupleStructInfo into a Tuple for R.call_tir Prior to this commit, the different `R.call_tir*` variations would wrap the arguments into an in-line `relax.Tuple`, if it is not already a `relax.Tuple`. While this allows a tensor to be passed into these functions as a single argument (`R.call_tir(func, arg, ...)` instead of `R.call_tir(func, [arg], ...)`), the wrapped Relax variable may already refer to a tuple. This use of a variable to refer to an argument tuple rather than an in-line argument tuple is not allowed by Relax. (See discussion on https://github.com/apache/tvm/pull/15916 for details.) However, by wrapping a variable `args: R.Tuple(R.Tensor, R.Tensor, ...)` into a tuple-of-tuples, the error occurs after the expression has already been generated, and refers to an expression `R.Tuple(R.Tuple(R.Tensor, R.Tensor, ...))` that doesn't appear anywhere in the user's input. This can make debugging difficult (see https://github.com/apache/tvm/issues/17239 for an example). This commit updates the argument-handling in `R.call_tir` to only generate an in-line `relax.Tuple` if the arguments do not already have `relax.TupleStructInfo`. If the argument was provided as a Relax variable bound to a tuple of arguments, it will still produce an error. However, that error will occur much earlier, and will explicitly state that the argument must be a `relax.Tuple` instead of a `relax.Var`. * lint fixes --- python/tvm/relax/op/base.py | 37 ++++++++++++++++----- tests/python/relax/test_tvmscript_parser.py | 36 ++++++++++++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 756d250c1687..03e86a4633a6 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # pylint: disable=redefined-builtin """The base Relax operators.""" + from typing import Dict, Union, List, Tuple, Optional, Callable @@ -25,7 +26,6 @@ from . import _ffi_api from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var -from ..expr import Tuple as RxTuple from ..struct_info import StructInfo, TensorStructInfo from ...ir import PrimExpr from ..utils import args_converter @@ -67,6 +67,29 @@ def null_value() -> Call: return _ffi_api.null_value() # type: ignore +def _wrap_inline_arg_tuple(args) -> Expr: + """Helper function to wrap argument tuple + + Normalize the arguments provided the functions that accept a tuple + of arguments, and require the tuple of arguments to be written + in-line. If the arguments provided are a single relax expression, + and are not a reference to a relax tuple, then wrap them into an + in-line relax Tuple. + + """ + if ( + isinstance(args, Expr) + and not isinstance(args, tvm.relax.Tuple) + and ( + args.struct_info_ is None + or not isinstance(args.struct_info_, tvm.relax.TupleStructInfo) + ) + ): + return tvm.relax.Tuple([args]) + else: + return args + + @args_converter.auto def call_tir( gvar: GlobalVar, @@ -98,8 +121,7 @@ def call_tir( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -153,8 +175,7 @@ def call_tir_with_grad( ret: Call A call node for the call_tir_with_grad operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -221,8 +242,7 @@ def call_tir_inplace( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(inplace_indices, list): inplace_indices = [inplace_indices] @@ -276,8 +296,7 @@ def call_dps_packed( if isinstance(func, str): func = ExternFunc(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _wrap_inline_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 4f41b662caf2..ea99d49270a1 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1044,6 +1044,42 @@ def main( _check(Module) +def test_call_tir_inplace_with_tuple_var_raises_error(): + + with pytest.raises(tvm.error.DiagnosticError): + + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")): + cls = Module + args = (x, y) + res = R.call_tir_inplace( + cls.copy, + # The `args` tuple must be an in-line tuple, not a + # reference to a tuple. This error should be + # caught and raised during parsing. + args, + inplace_indices=[0, -1], + out_sinfo=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], + ) + return res + + @T.prim_func + def copy( + A: T.Buffer((2, 3), "int32"), + B: T.Buffer((2, 3), "int32"), + out1: T.Buffer((2, 3), "int32"), + ): + # copies the contents of B into A and out1 + T.func_attr({"tir.noalias": True}) + for iters in T.grid(T.int64(2), T.int64(3)): + with T.block("T_zeros"): + i, j = T.axis.remap("SS", iters) + A[i, j] = B[i, j] + out1[i, j] = B[i, j] + + def test_local_function(): @R.function def main( From c61982e2cd74b29dd43455da390c456e53010307 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Mon, 26 Aug 2024 21:55:57 +0800 Subject: [PATCH 086/202] [TE][CreatePrimFunc] Fix create reduce block with spatial iter dependent init value (#17301) fix create reduce block with spatial iter dependent init value Co-authored-by: wrongtest --- src/te/operation/create_primfunc.cc | 17 +++-- tests/python/te/test_te_create_primfunc.py | 73 ++++++++++++++++++++++ 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index b5a87d9446d8..31815fc71060 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -228,6 +228,10 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, } // Step 4. Create block body. + // helper to transform the expr and remap iters to the block domain + auto f_transform_and_remap = [&](const PrimExpr& e) { + return Substitute(info->transformer(e), var_map); + }; String block_name{nullptr}; Optional init = NullOpt; Stmt body; @@ -246,8 +250,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, // - A RHS operand is the value to be reduced. for (int i = 0; i < n_buffers; ++i) { const PrimExpr& left = BufferLoad(buffers[i], indices); - const PrimExpr& right = - analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map)); + const PrimExpr& right = analyzer->Simplify(f_transform_and_remap(reduce->source[i])); lhs.push_back(left); rhs.push_back(right); ICHECK_EQ(left->dtype, right->dtype); @@ -267,13 +270,15 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, // then store the value of the variables into the target buffer positions. for (int i = 0; i < n_buffers; ++i) { const Buffer& buffer = buffers[i]; - init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices)); + PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]); + init_stmts.push_back(BufferStore(buffer, identity, indices)); PrimExpr value{nullptr}; if (n_buffers > 1) { temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype()))); value = temp_vars.back(); } else { - value = reduce->combiner.get()->operator()(lhs, rhs)[i]; + PrimExpr combined = reduce->combiner.get()->operator()(lhs, rhs)[i]; + value = f_transform_and_remap(combined); } body_stmts.push_back(BufferStore(buffer, value, indices)); } @@ -283,7 +288,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, if (n_buffers > 1) { // When there are multiple buffers, we wrap the body with LetStmts. for (int i = n_buffers - 1; i >= 0; --i) { - PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i]; + PrimExpr value = f_transform_and_remap(reduce->combiner.get()->operator()(lhs, rhs)[i]); body = LetStmt(temp_vars[i], std::move(value), std::move(body)); } } @@ -291,7 +296,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, // Case 2. Data parallel compute ICHECK_EQ(tensors.size(), 1); block_name = info->FreshName(tensors[0]->GetNameHint()); - const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map); + const PrimExpr& compute_body = f_transform_and_remap(expr_body); body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices); } diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index ade414f4234f..1a7e03188a25 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -814,5 +814,78 @@ def test_with_var_input(): _check_workload(te_slice_with_var_input, tir_slice_with_var_input, index_dtype_override="int64") +def test_loop_aware_initial_value(): + """Test initial value aware of spatial iter position""" + + @T.prim_func + def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"}) + a = T.match_buffer(var_a, (5, 5)) + b = T.match_buffer(var_b, (5,)) + sum_red = T.match_buffer(var_sum_red, (5,)) + for i, ax in T.grid(5, 5): + with T.block("sum_red"): + v_i, v_ax = T.axis.remap("SR", [i, ax]) + T.reads(b[v_i], a[v_i, v_ax]) + T.writes(sum_red[v_i]) + with T.init(): + sum_red[v_i] = b[v_i] + sum_red[v_i] = sum_red[v_i] + a[v_i, v_ax] + + def te_workload(): + data = te.placeholder((5, 5), "float32", "a") + init = te.placeholder((5,), "float32", "b") + ax = te.reduce_axis((0, 5), "ax") + sum_red = te.compute( + (5,), + lambda i: te.comm_reducer( + lambda x, y: x + y, + lambda t: init[i], + )(data[i, ax], axis=[ax]), + name="sum_red", + ) + return [data, init, sum_red] + + _check_workload(te_workload, tir_workload) + + +def test_loop_aware_reducer_combiner(): + """Test combiner aware of spatial iter position""" + + @T.prim_func + def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"}) + a = T.match_buffer(var_a, (5, 5)) + b = T.match_buffer(var_b, (5,)) + sum_red = T.match_buffer(var_sum_red, (5,)) + for i, ax in T.grid(5, 5): + with T.block("sum_red"): + v_i = T.axis.spatial(5, i) + v_ax = T.axis.reduce(5, ax) + T.reads(a[v_i, 0:5]) + T.writes(sum_red[v_i]) + with T.init(): + sum_red[v_i] = T.float32(0.0) + sum_red[v_i] = T.if_then_else( + a[v_i, sum_red[v_i]] < a[v_i, v_ax], sum_red[v_i], T.Cast("float32", v_ax) + ) + + def te_workload(): + data = te.placeholder((5, 5), "float32", "a") + init = te.placeholder((5,), "float32", "b") + ax = te.reduce_axis((0, 5), "ax") + sum_red = te.compute( + (5,), + lambda i: te.comm_reducer( + lambda x, y: te.if_then_else(data[i, x] < y, x, ax), + lambda _: te.const(0, "float32"), + )(data[i, ax], axis=[ax]), + name="sum_red", + ) + return [data, init, sum_red] + + _check_workload(te_workload, tir_workload) + + if __name__ == "__main__": tvm.testing.main() From 3138328207bbe0b519c33a2f59be8ef2cf44d5b7 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 26 Aug 2024 21:20:05 -0400 Subject: [PATCH 087/202] [Runtime] Support KV cache with RoPE extension factor array (#17294) This PR enhances the KV cache with the RoPE extensio factor support. With this PR, the KV cache can support models like Phi3.5 which comes with the extension factor. --- src/runtime/relax_vm/kv_state.h | 1 + src/runtime/relax_vm/paged_kv_cache.cc | 63 +++++++++++-------- ...tin_paged_attention_kv_cache_flashinfer.py | 3 + ...me_builtin_paged_attention_kv_cache_tir.py | 1 + 4 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index f4d6036b9638..6d30ce998add 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -167,6 +167,7 @@ class AttentionKVCacheObj : public KVStateObj { * `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)`. * \param mask The input mask data, in layout `(total_sqr_length)`. * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`. + * \param attn_score_scaling_factor The additional attention scaling factor. * \sa AttentionKVCache::Attention */ virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 6bf3dc7ce609..591187ab5fe7 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -848,6 +848,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const double rotary_scale_; /*! \brief The RoPE theta. */ const double rotary_theta_; + /*! \brief The optional RoPE extension factors for RoPE scaling. */ + const Optional rope_ext_factors_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); @@ -988,7 +990,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, // int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, - RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType dtype, Device device, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, + Optional rope_ext_factors, DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill, PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged, @@ -1013,6 +1016,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { : rope_mode), rotary_scale_(rotary_scale), rotary_theta_(rotary_theta), + rope_ext_factors_(std::move(rope_ext_factors)), f_transpose_append_(std::move(f_transpose_append)), f_compact_copy_(std::move(f_compact_copy)), f_attention_prefill_(std::move(f_attention_prefill)), @@ -1132,6 +1136,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, preferred_host_device, copy_stream_); } + + // Right now only the "normal" RoPE mode supports the RoPE extention factors. + if (rope_ext_factors_.defined()) { + CHECK(rope_mode_ == RoPEMode::kNormal) + << "The RoPE mode must be normal to support RoPE extension factors."; + } } ~PagedAttentionKVCacheObj() { @@ -1726,8 +1736,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_}, qkv_data->dtype); // Part 2. Split fused qkv and apply rotary embedding to q/k data. - f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, - static_cast(rope_mode_ == RoPEMode::kNormal)); + if (!rope_ext_factors_.defined()) { + f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, + static_cast(rope_mode_ == RoPEMode::kNormal)); + } else { + f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, + rope_ext_factors_.value()); + } // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set. if (append_before_attn_) { @@ -2462,7 +2477,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27) + CHECK(args.size() == 27 || args.size() == 28) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; @@ -2499,14 +2514,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") PackedFunc f_split_rotary = args[22]; PackedFunc f_copy_single_page = args[23]; Optional f_debug_get_kv = args[24]; - PackedFunc f_compact_copy{nullptr}; - PackedFunc f_attention_prefill_with_tree_mask{nullptr}; + PackedFunc f_compact_copy = args[25]; + PackedFunc f_attention_prefill_with_tree_mask = args[26]; + Optional rope_ext_factors = NullOpt; - if (args.size() >= 26) { - f_compact_copy = args[25].AsObjectRef(); - } - if (args.size() >= 27) { - f_attention_prefill_with_tree_mask = args[26].AsObjectRef(); + if (args.size() >= 28 && args[27].IsObjectRef()) { + rope_ext_factors = args[27].AsObjectRef(); } CHECK_EQ(cache_config.size(), 5); @@ -2523,9 +2536,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") ObjectPtr n = make_object( page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, - RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device, - std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill), - std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), + RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors), // + init->dtype, init->device, std::move(f_transpose_append), std::move(f_compact_copy), + std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), std::move(f_attention_prefill_ragged_begin_forward), @@ -2539,7 +2553,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21) + CHECK(args.size() == 21 || args.size() == 22) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; @@ -2570,14 +2584,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") PackedFunc f_split_rotary = args[16]; PackedFunc f_copy_single_page = args[17]; Optional f_debug_get_kv = args[18]; - PackedFunc f_compact_copy{nullptr}; - PackedFunc f_attention_prefill_with_tree_mask{nullptr}; + PackedFunc f_compact_copy = args[19]; + PackedFunc f_attention_prefill_with_tree_mask = args[20]; + Optional rope_ext_factors = NullOpt; - if (args.size() >= 20) { - f_compact_copy = args[19].AsObjectRef(); - } - if (args.size() >= 21) { - f_attention_prefill_with_tree_mask = args[20].AsObjectRef(); + if (args.size() >= 22 && args[21].IsObjectRef()) { + rope_ext_factors = args[21].AsObjectRef(); } CHECK_EQ(cache_config.size(), 5); @@ -2594,9 +2606,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") ObjectPtr n = make_object( page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, - RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device, - std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill), - std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), + RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors), // + init->dtype, init->device, std::move(f_transpose_append), std::move(f_compact_copy), + std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), // NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index cab10f84cddf..2252cb8d9c09 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -379,6 +379,9 @@ def create_kv_cache(rope_mode): fsplit_rotary, fcopy_single_page, fcopy_cache, + None, + None, + None, ) return cache diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 96a2438505b2..ff655e141b96 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -180,6 +180,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): fcopy_cache, fcompact_copy, fattn_prefill_with_tree_mask, + None, ) return cache From bf7bbefd36ac91242496d533d2bfff71570bf04a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 27 Aug 2024 10:19:28 -0400 Subject: [PATCH 088/202] [Python][Relax] Rotary positional embedding scaling (#17305) This PR introduces two styles of RoPE scaling: the llama3 style and the longrope scale. --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 396 ++++++++++++++++-- .../frontend/nn/llm/position_embedding.py | 191 ++++++++- python/tvm/relax/frontend/nn/llm/tree_attn.py | 26 +- ...me_builtin_paged_attention_kv_cache_tir.py | 19 +- 4 files changed, 579 insertions(+), 53 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 25a3a1a00ddc..5ddce76eab40 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -20,7 +20,7 @@ # pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name import enum import math -from typing import Tuple +from typing import Any, Dict, Tuple from tvm import relax as rx from tvm import tir @@ -29,7 +29,7 @@ from tvm.script import tir as T from tvm.target import Target -from .position_embedding import llama_rope_with_position_map, rope_freq +from .position_embedding import llama_rope_with_position_map, switch_rope_freq_func from .tree_attn import tree_attn @@ -166,6 +166,8 @@ def __init__( # pylint: disable=too-many-locals rope_mode: RopeMode, rope_scale: int, rope_theta: int, + rope_scaling: Dict[str, Any], + rope_ext_factors: rx.Expr, rotary_dim: int, dtype: str, target: Target, @@ -195,6 +197,9 @@ def __init__( # pylint: disable=too-many-locals 0 or 1, denoting whether the KV cache supports sliding window. It is a symbolic variable whose concrete value is specified at runtime. + layer_partition : rx.ShapeExpr + The KV cache layer partition for pipeline stages. + It is an indptr array, denoting the starting layer of each pipeline stage. rope_mode : RopeMode The RoPE mode of the Paged KV cache. If it is normal, RoPE will be applied to k before adding k to cache. @@ -205,6 +210,8 @@ def __init__( # pylint: disable=too-many-locals The base of rotary position embedding. rope_scaling: Dict[str, Any] The RoPE scaling information dict. + rope_ext_factors: rx.Expr + The RoPE extension factors when "longrope" mode RoPE scaling is enabled. rotary_dim : int The number of dimensions in the embedding that RoPE is applied to. """ @@ -235,8 +242,8 @@ def __init__( # pylint: disable=too-many-locals bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache"), rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), @@ -245,11 +252,12 @@ def __init__( # pylint: disable=too-many-locals rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward"), rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"), rx.extern("flashinfer.merge_state_in_place"), - bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), - bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), + rope_ext_factors, # fmt: on # pylint: enable=line-too-long ] @@ -281,6 +289,8 @@ def __init__( # pylint: disable=too-many-locals head_dim: int, rope_scale: int, rope_theta: int, + rope_scaling: Dict[str, Any], + rope_ext_factors: rx.Expr, rotary_dim: int, dtype: str, target: Target, @@ -321,6 +331,10 @@ def __init__( # pylint: disable=too-many-locals The scale of rotary position embedding. rope_theta : int The base of rotary position embedding. + rope_scaling: Dict[str, Any] + The RoPE scaling information dict. + rope_ext_factors: rx.Expr + The RoPE extension factors when "longrope" mode RoPE scaling is enabled. rotary_dim : int The number of dimensions in the embedding that RoPE is applied to. target : Target @@ -349,17 +363,18 @@ def __init__( # pylint: disable=too-many-locals # pylint: disable=line-too-long # fmt: off bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, target), "tir_attention_prefill"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, target), "tir_attention_decode"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), - bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_decode"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window"), + bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged"), bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"), - bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), - bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), + rope_ext_factors, # fmt: on # pylint: enable=line-too-long ] @@ -464,17 +479,23 @@ def _rope( theta: tir.Var, scale: tir.Var, indices: Tuple[tir.Var, ...], - qkv_dtype="float16", + qkv_dtype: str, + rope_scaling: Dict[str, Any], ): d = indices[-1] - cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, "float32") + cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)( + offset * scale, d, rotary_dim, theta, "float32" + ) cos = cos_freq * buffer[indices].astype("float32") sin = sin_freq * tir.if_then_else( d < rotary_dim // 2, -buffer[indices[:-1] + (d + rotary_dim // 2,)], buffer[indices[:-1] + (d - rotary_dim // 2,)], ).astype("float32") - return (cos + sin).astype(qkv_dtype) + expr = (cos + sin).astype(qkv_dtype) + for var, value in var_map.items(): + expr = tir.Let(var, value, expr) + return expr def _var(dtype): @@ -520,7 +541,9 @@ def _get_seq_offset(pos, seq_id, length_info, sliding_window): ) -def _attention_prefill(h_kv, h_q, d, dtype, sliding_window: bool, target: Target): +def _attention_prefill( + h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], target: Target +): NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv @@ -680,7 +703,7 @@ def batch_prefill_paged_kv( if cur_L < q_indptr[b_idx + 1]: Q_smem[i, j] = T.if_then_else( rotary_mode == 1, - _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling), q[cur_L, cur_H_qo, j] ) else: @@ -701,7 +724,7 @@ def batch_prefill_paged_kv( page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore K_smem[i, j] = T.if_then_else( rotary_mode == 1, - _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype), + _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype, rope_scaling), pages[page_no, 0, by, page_offset, j] ) else: @@ -890,6 +913,7 @@ def _attention_decode( head_dim, qkv_dtype, sliding_window: bool, + rope_scaling: Dict[str, Any], target: Target, ): qkv_dtype_bytes = 2 @@ -1023,7 +1047,7 @@ def batch_decode_paged_kv( for vec in T.vectorized(VEC_SIZE): Q_local[vec] = T.if_then_else( rotary_mode == 1, - _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype), + _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype, rope_scaling), Q[bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] ) @@ -1043,7 +1067,7 @@ def batch_decode_paged_kv( for vec in T.vectorized(VEC_SIZE): K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( rotary_mode == 1, - _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), + _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype, rope_scaling), pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] ) V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] @@ -1210,7 +1234,331 @@ def merge_state_inplace( return merge_state_inplace -def _attention_prefill_ragged(h_kv, h_q, d, dtype, target: Target): +def _attention_sequence_prefill( + batch_size, h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 +): # pylint: disable=line-too-long + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + bdx = 32 + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + + # fmt: off + @T.prim_func + def batch_sequence_prefill_kv( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle # [total_len, h_q] + ): + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype) + k = T.match_buffer(var_k, (batch_size, kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (batch_size, kv_len, h_kv, d), dtype) + output = T.match_buffer(var_output, (batch_size, qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (batch_size, qo_len, h_q), dtype) # pylint: disable=unused-variable + + batch_tiles: T.int32 = T.ceildiv(qo_len * group_size, tile_x) + + # kernel code + for lbx in T.thread_binding(T.cast(batch_size, "int32") * batch_tiles, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + vbx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + + m_new = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + m_prev = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + d_new = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + + b_idx: T.int32 = vbx // batch_tiles + tile_id: T.int32 = vbx % batch_tiles + LH_start: T.int32 = tile_id * tile_x + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size + if cur_L < qo_len: + Q_smem[i, j] = q[b_idx, cur_L, cur_H_qo, j] + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_len, tile_z)): + L_kv_start: T.int32 = iterator * tile_z + L_kv_base: T.int32 = 0 + for lz, ly in T.grid(tile_z, tile_y): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_len: + K_smem[i, j] = k[ + b_idx, L_kv_base + cur_L, by, j + ] + else: + K_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, tile_y): + with T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_len: + V_smem[i, j] = v[ + b_idx, L_kv_base + cur_L, by, j + ] + else: + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += ( + T.cast(Q_smem[i, k], "float32") + * T.cast(K_smem[j, k], "float32") + * attn_score_scaling_factor + * sm_scale + ) + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size + for j in T.serial(tile_z): + if _causal_mask( + causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_len, + qo_len=qo_len, + ): + m_new[i] = T.max( + m_new[i], S_smem[row, j] + ) + d_new[i] = d_smem[row] * T.exp2( + m_prev[i] - m_new[i] + ) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = ( + LH_start + row + ) // group_size + if _causal_mask( + causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_len, + qo_len=qo_len, + ): + S_smem[row, j] = T.exp2( + S_smem[row, j] - m_new[i] + ) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2( + m_prev_smem[i] - m_smem[i] + ) + O_local[i, j] += S_smem[i, k] * T.cast( + V_smem[k, j], "float32" + ) + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = 0 + (LH_start + i) // group_size + cur_H_qo: T.int32 = ( + by * group_size + (LH_start + i) % group_size + ) + if cur_L < qo_len: + output[b_idx, cur_L, cur_H_qo, j] = ( + O_local[i, j] / d_smem[i] + ) + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = 0 + (LH_start + i) // group_size + cur_H_qo: T.int32 = ( + by * group_size + (LH_start + i) % group_size + ) + if cur_L < qo_len: + lse[b_idx, cur_L, cur_H_qo] = m_smem[i] + T.log2( + d_smem[i] + ) + + # fmt: on + # pylint: enable=line-too-long,too-many-branches + sch = tir.Schedule(batch_sequence_prefill_kv) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_schedule(sch): + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + + apply_schedule(sch) + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target): # pylint: disable=line-too-long NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes @@ -1344,7 +1692,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches if cur_L < q_indptr[b_idx + 1]: Q_smem[i, j] = T.if_then_else( rotary_mode == 1, - _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling), q[cur_L, cur_H_qo, j] ) else: @@ -1363,7 +1711,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches if cur_L < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else( rotary_mode == 1, - _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype), + _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype, rope_scaling), k[L_kv_base + cur_L, by, j] ) else: diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index b224ce04c597..4373395e3214 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -17,7 +17,9 @@ """Operators for positional embeddings, e.g. RoPE.""" -from typing import Optional, Tuple +import math +from functools import partial +from typing import Any, Callable, Dict, Optional, Tuple from tvm import tir from tvm.relax.frontend.nn import Tensor, op @@ -26,7 +28,7 @@ # pylint: disable=invalid-name -def rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): +def rope_freq_default(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): """Compute the inverse frequency of RoPE and then return the cosine and sine of it. Parameters @@ -53,11 +55,95 @@ def rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): sin_freq : Tensor The sine of the inverse frequency. + + var_map: Dict[tir.Var, tir.PrimExpr] + The common expression map. """ freq = s / tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) - cos_freq = tir.cos(freq).astype(dtype) - sin_freq = tir.sin(freq).astype(dtype) - return cos_freq, sin_freq + freq_var = tir.Var("freq", "float32") + cos_freq = tir.cos(freq_var).astype(dtype) + sin_freq = tir.sin(freq_var).astype(dtype) + return cos_freq, sin_freq, {freq_var: freq} + + +def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals + s: tir.Var, + d: tir.Var, + d_range: int, + theta: float, + dtype: str, + factor: float, + low_freq_factor: float, + high_freq_factor: float, + original_max_position_embeddings: float, +): + """Compute the inverse frequency of RoPE for llama3 RoPE scaling.""" + orig_freq = tir.const(1, "float32") / tir.power( + theta, d * 2 % d_range / tir.const(d_range, "float32") + ) + orig_freq_var = tir.Var("orig_freq", "float32") + inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) + llama3_inv_scaling_factor = 1.0 / factor + llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor + llama3_beta = low_freq_factor * inv_diff_freq_factor + smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) + smoothed_freq = s * ( + (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var + ) + smoothed_freq_var = tir.Var("smoothed_freq", "float32") + cos_freq = tir.cos(smoothed_freq_var).astype(dtype) + sin_freq = tir.sin(smoothed_freq_var).astype(dtype) + return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq} + + +def rope_freq_longrope( # pylint: disable=too-many-arguments + s: tir.Var, + d: tir.Var, + d_range: int, + theta: float, + dtype: str, + max_position_embeddings: int, + original_max_position_embeddings: int, + ext_factors: Optional[T.Buffer] = None, +): + """Compute the inverse frequency of RoPE for longrope scaling.""" + scale = max_position_embeddings / original_max_position_embeddings + scaling_factor = ( + math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings)) + if scale > 1.0 + else 1.0 + ) + divisor = tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) + if ext_factors is not None: + divisor = ext_factors[d % (d_range // 2)] * divisor + freq = s / divisor + freq_var = tir.Var("freq", "float32") + cos_freq = (tir.cos(freq_var) * scaling_factor).astype(dtype) + sin_freq = (tir.sin(freq_var) * scaling_factor).astype(dtype) + return cos_freq, sin_freq, {freq_var: freq} + + +def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable: + """Return the RoPE inverse frequency computation function based + on the given RoPE scaling. + """ + if "rope_type" not in rope_scaling: + return rope_freq_default + if rope_scaling["rope_type"] == "llama3": + return partial( + rope_freq_llama3, + factor=rope_scaling["factor"], + low_freq_factor=rope_scaling["low_freq_factor"], + high_freq_factor=rope_scaling["high_freq_factor"], + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + ) + if rope_scaling["rope_type"] == "longrope": + return partial( + rope_freq_longrope, + max_position_embeddings=rope_scaling["max_position_embeddings"], + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + ) + raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}') # mypy: disable-error-code="attr-defined" @@ -67,9 +153,10 @@ def llama_rope( # pylint: disable=too-many-arguments qkv: Tensor, total_seq_len: tir.Var, theta: float, + scale: float, num_q_heads: int, num_kv_heads: int, - scale: float = 1.0, + rope_scaling: Dict[str, Any], rotary_dim: Optional[int] = None, ) -> Tuple[Tensor, Tensor, Tensor]: """Llama-style RoPE. Given a fused QKV tensor, it returns three tensors, Q, K, and V, where Q @@ -96,6 +183,9 @@ def llama_rope( # pylint: disable=too-many-arguments num_kv_heads : int The number of key/value heads. It differs from `num_q_heads` in group-query attention. + rope_scaling : Dict + The configuration of RoPE scaling. + rotary_dim : Optional[int] The number of dimensions in the embedding that RoPE is applied to. By default, the rotary_dim is the same as head_dim. @@ -126,14 +216,19 @@ def _rope( # pylint: disable=too-many-arguments d: tir.Var, offset: tir.Var, ): - cos_freq, sin_freq = rope_freq((s + offset) * scale, d, rotary_dim, theta, dtype) + cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)( + (s + offset) * scale, d, rotary_dim, theta, dtype + ) cos = cos_freq * x[b, s, h, d] sin = sin_freq * tir.if_then_else( d < rotary_dim // 2, -x[b, s, h, d + rotary_dim // 2], x[b, s, h, d - rotary_dim // 2], ) - return cos + sin + expr = cos + sin + for var, value in var_map.items(): + expr = tir.Let(var, value, expr) + return expr @T.prim_func(private=True) def fused_rope( # pylint: disable=too-many-locals @@ -193,6 +288,7 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments num_q_heads: int, num_kv_heads: int, dtype: str, + rope_scaling: Dict[str, Any], rotary_dim: Optional[int] = None, ): """Return the TIR function that computes Llama-style RoPE with q position map. @@ -217,6 +313,9 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments dtype : str The dtype of qkv data. + rope_scaling : Dict + The configuration of RoPE scaling. + rotary_dim : int The number of dimensions in the embedding that RoPE is applied to. By default, the rotary_dim is the same as head_dim. @@ -225,6 +324,7 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments if rotary_dim is None: rotary_dim = head_dim scale = tir.const(scale, "float32") + is_longrope_scaling = rope_scaling.get("rope_type") == "longrope" def _rope( # pylint: disable=too-many-arguments x: T.Buffer, @@ -232,15 +332,24 @@ def _rope( # pylint: disable=too-many-arguments h: tir.Var, d: tir.Var, pos: tir.Var, + ext_factors: Optional[T.Buffer] = None, ): - cos_freq, sin_freq = rope_freq(pos * scale, d, rotary_dim, theta, "float32") + kwargs = {} + if ext_factors: + kwargs["ext_factors"] = ext_factors + cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)( + pos * scale, d, rotary_dim, theta, "float32", **kwargs + ) cos = cos_freq * x[s, h, d].astype("float32") sin = sin_freq * tir.if_then_else( d < rotary_dim // 2, -x[s, h, d + rotary_dim // 2], x[s, h, d - rotary_dim // 2], ).astype("float32") - return (cos + sin).astype(dtype) + expr = (cos + sin).astype(dtype) + for var, value in var_map.items(): + expr = tir.Let(var, value, expr) + return expr @T.prim_func def fused_rope( # pylint: disable=too-many-locals @@ -257,8 +366,8 @@ def fused_rope( # pylint: disable=too-many-locals "tir.noalias": T.bool(True), } ) - seq_len = T.int64() - position_map_elem_offset = T.int64() + seq_len = T.int32() + position_map_elem_offset = T.int32() qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) @@ -284,4 +393,62 @@ def fused_rope( # pylint: disable=too-many-locals else: v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + @T.prim_func + def fused_rope_longrope_scaling( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + ext_factors: T.Buffer((head_dim // 2,), "float32"), # type: ignore + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": T.bool(True), + } + ) + seq_len = T.int64() + position_map_elem_offset = T.int64() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + ext_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + ext_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + + if is_longrope_scaling: + return fused_rope_longrope_scaling return fused_rope diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 486491dbf2c6..069eb4892348 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -19,14 +19,14 @@ """Operators for tree attention.""" import math -from typing import Tuple +from typing import Any, Dict, Tuple from tvm import tir from tvm.runtime import DataType from tvm.script import tir as T from tvm.target import Target -from .position_embedding import rope_freq +from .position_embedding import switch_rope_freq_func # mypy: disable-error-code="attr-defined,valid-type,no-redef" # pylint: disable=too-many-statements,too-many-locals,too-many-arguments @@ -43,24 +43,30 @@ def _rope( theta: tir.Var, scale: tir.Var, indices: Tuple[tir.Var, ...], - qkv_dtype="float16", + qkv_dtype: str, + rope_scaling: Dict[str, Any], ): d = indices[-1] - cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, qkv_dtype) - cos = cos_freq * buffer[indices] + cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)( + offset * scale, d, rotary_dim, theta, "float32" + ) + cos = cos_freq * buffer[indices].astype("float32") sin = sin_freq * tir.if_then_else( d < rotary_dim // 2, -buffer[indices[:-1] + (d + rotary_dim // 2,)], buffer[indices[:-1] + (d - rotary_dim // 2,)], - ) - return cos + sin + ).astype("float32") + expr = (cos + sin).astype(qkv_dtype) + for var, value in var_map.items(): + expr = tir.Let(var, value, expr) + return expr def _tree_mask(row, col, mask_ptr, offset, stride, kv_len): return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1) -def tree_attn(h_kv, h_q, d, dtype, target: Target): # pylint: disable=unused-argument +def tree_attn(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target): """Generate tree attention kernel for batched tree attention. Parameters @@ -217,7 +223,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches if cur_L < q_indptr[b_idx + 1]: Q_smem[i, j] = T.if_then_else( rotary_mode == 1, - _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling), q[cur_L, cur_H_qo, j] ) else: @@ -236,7 +242,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches if L_kv_start + i < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else( rotary_mode == 1, - _rope(k, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, by, j), dtype), + _rope(k, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, by, j), dtype, rope_scaling), k[cur_L, by, j] ) V_smem[i, j] = v[cur_L, by, j] diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index ff655e141b96..c35b7062cdc2 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -49,6 +49,7 @@ head_dim = None rope_scale = 1.0 rope_theta = 1e4 +rope_scaling = {} dtype = None device = tvm.cuda() @@ -113,15 +114,19 @@ def set_global_func(head_dim, dtype): for tir_func in [ _kv_cache_transpose_append(num_kv_heads, head_dim, dtype), _kv_cache_debug_get_kv(num_layers, num_kv_heads, head_dim, dtype), - _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, False, target), - _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, False, target), - _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, target), - _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, True, target), - _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype, target), - tree_attn(num_kv_heads, num_qo_heads, head_dim, dtype, target), + _attention_prefill( + num_kv_heads, num_qo_heads, head_dim, dtype, False, rope_scaling, target + ), + _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, False, rope_scaling, target), + _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling, target), + _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling, target), + _attention_prefill_ragged( + num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target + ), + tree_attn(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target), _merge_state_inplace(num_qo_heads, head_dim, dtype, target), llama_rope_with_position_map( - rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype + rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype, rope_scaling ), _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target), _compact_kv_copy(num_kv_heads, head_dim, dtype, target), From 99defd25c40c75b00395df1d2d58c84d2e0bd9ca Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 28 Aug 2024 04:37:30 +0900 Subject: [PATCH 089/202] [Relax][PyTorch] Add support for torch.repeat (#17304) * add test * add support for torch.repeat * remove debug print --- .../tvm/relax/frontend/torch/fx_translator.py | 9 +++++ tests/python/relax/test_frontend_from_fx.py | 36 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 6d01283d3ecd..676f63b5c359 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -640,6 +640,14 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var: dim = None return self.block_builder.emit(relax.op.squeeze(x, dim)) + def _repeat(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.tile(args[0], args[1:])) + def _tile(self, node: fx.node.Node) -> relax.Var: import torch # type: ignore @@ -1484,6 +1492,7 @@ def create_convert_map(self): "expand": self._expand, "flatten": self._flatten, "permute": self._permute, + "repeat": self._repeat, "reshape": self._reshape, "split": self._split, "tile": self._tile, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 5398fe342073..c6c4f2597260 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3311,6 +3311,42 @@ def main( verify_model(Transpose(), input_info, {}, expected1) +def test_repeat(): + class Tile1(Module): + def forward(self, x: torch.Tensor): + return x.repeat(2) + + class Tile2(Module): + def forward(self, x: torch.Tensor): + return x.repeat(4, 2) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((6,), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2) + gv: R.Tensor((6,), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) + gv: R.Tensor((4, 6), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tile1(), [([3], "float32")], {}, expected1) + verify_model(Tile2(), [([1, 3], "float32")], {}, expected2) + verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2) + + def test_view(): input_info = [([1, 2, 3, 4], "float32")] From be8607d47fa418f6bf77671b81093e0ffd7fdc4d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 28 Aug 2024 17:43:54 -0500 Subject: [PATCH 090/202] [Relax][Bugfix] Infer TIR values from shapes inside a tuple (#17312) If a Relax function contains an `R.match_cast` that defines a symbolic shape, and the value provided to the `R.match_cast` has a known static shape, the `relax.transform.CanoncalizeBindings()` pass can in-line the known static shape. However, while these known TIR values were only collected if the expression used in `R.match_cast` was a `R.Tensor`, `R.Shape`, and `R.Prim` (Relax types which may contain symbolic TIR values), they were not collected if the `R.match_cast` expression was a `R.Tuple`. For example, while using `R.match_cast` to convert from `R.Tensor([16])` to `R.Tensor([batch_size])` would identify that `batch_size` must be `16`, using `R.match_cast` to convert from `R.Tuple(R.Tensor([16]))` to `R.Tuple(R.Tensor([batch_size]))` would not. This commit updates the `InferSymbolicVarMap` to collect all symbolic shapes, even if they occur within a `R.Tuple`. --- src/relax/utils.cc | 27 ++++++++++++--- .../test_transform_canonicalize_bindings.py | 34 +++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 77416dc92b1d..96fd5578e40a 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -159,13 +159,32 @@ tvm::Map InferSymbolicVarMap( GetStructInfo(expr_tensor->shape.value())); }; + std::function bind_from_struct_info = nullptr; + auto bind_from_tuple = [&bind_from_struct_info](const StructInfo& var, const StructInfo& expr) { + auto var_tuple = var.as(); + if (!var_tuple) return; + + auto expr_tuple = expr.as(); + if (!expr_tuple) return; + + if (var_tuple->fields.size() != expr_tuple->fields.size()) return; + + for (size_t i = 0; i < var_tuple->fields.size(); i++) { + bind_from_struct_info(var_tuple->fields[i], expr_tuple->fields[i]); + } + }; + + bind_from_struct_info = [&](const StructInfo& var, const StructInfo& expr) { + bind_from_tensor(var, expr); + bind_from_shape(var, expr); + bind_from_prim_value(var, expr); + bind_from_tuple(var, expr); + }; + for (const auto& [relax_var, relax_expr] : relax_var_remap) { auto var_sinfo = GetStructInfo(relax_var); auto expr_sinfo = GetStructInfo(relax_expr); - - bind_from_tensor(var_sinfo, expr_sinfo); - bind_from_shape(var_sinfo, expr_sinfo); - bind_from_prim_value(var_sinfo, expr_sinfo); + bind_from_struct_info(var_sinfo, expr_sinfo); } return tir_var_remap; diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index ea3b1c249b8b..a7ff8cdc3202 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -253,6 +253,40 @@ def main(x: R.Tensor(("m", "n"))): verify(TestChangeShape, Expected) +def test_replace_symbolic_variable_and_remove_match_cast_of_tuple(): + """Symbolic variables may be defined in R.match_cast of tuple + + This test is similar to + `test_replace_symbolic_variable_and_remove_match_cast`, except + that the MatchCast is performed on a Relax tuple. + + This is a regression test. Earlier implementations only inferred + TIR variables from `R.match_cast` of tensors, shapes, and prim + values, but omitted tuples. + + """ + + @I.ir_module + class Before: + @R.function + def main(x: R.Tuple(R.Tensor(("m", "n")))): + y = x + o, p = T.int64(), T.int64() + z = R.match_cast(x, R.Tuple(R.Tensor((o, p)))) + w = z + q = R.add(w[0], y[0]) + return R.add(q, w[0]) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tuple(R.Tensor(("m", "n")))): + q = R.add(x[0], x[0]) + return R.add(q, x[0]) + + verify(Before, Expected) + + def test_unwrap_tuple(): @I.ir_module class Before: From 108a4e15b3c68fea2f803dc13b1b45291b00f15b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 28 Aug 2024 18:29:18 -0500 Subject: [PATCH 091/202] [Relax] Identify tuple unpack/repack in CanonicalizeBindings (#17313) Prior to this commit, the `CanonicalizeBindings` pass could identify and simplify a value that had been packed into a tuple, then extracted from it. (e.g. Simplifying `tup = (x,y); z = tup[0]` into `z = x`.) However, it could not identify a value that had been expanded from a tuple, and then re-bundled. (e.g. Simplifying `new_tuple = (tup[0], tup[1])` into `new_tuple = tup`.) This commit updates `CanonicalizeBindings` to identify and remove unnecessary tuple unpacking/repacking. --- src/relax/transform/canonicalize_bindings.cc | 112 ++++++++++++++---- .../test_transform_canonicalize_bindings.py | 51 ++++++++ 2 files changed, 143 insertions(+), 20 deletions(-) diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index d1a9f97337de..807914075e8d 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -262,33 +262,105 @@ class CanonicalizePlanner : public ExprVisitor { current_block_ = Optional(); } - void VisitBinding(const Binding& binding) override { - bool has_same_struct_info = true; - Expr value; - if (auto ptr = binding.as()) { - value = ptr->value; - } else if (auto ptr = binding.as()) { - has_same_struct_info = - StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value)); - value = ptr->value; - } else { - LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); - } + Optional UnwrapKnownValue(Expr expr) { + // If the expression is a variable, then it can be unwrapped into + // its known value. + auto unwrap_var = [this](Expr expr) -> Expr { + if (auto var = expr.as()) { + if (auto opt = known_bindings_.Get(var.value())) { + return opt.value(); + } + } + return expr; + }; - // Unwrap TupleGetItem, if the Tuple being accessed is known. - if (auto tuple_get_item = value.as()) { - Expr tuple = tuple_get_item->tuple; - while (auto tuple_var = tuple.as()) { - if (auto opt = known_bindings_.Get(tuple_var.value())) { - tuple = opt.value(); + auto recursively_unwrap_var = [&unwrap_var](Expr expr) -> Expr { + while (true) { + auto new_expr = unwrap_var(expr); + if (new_expr.same_as(expr)) { + return expr; } else { - break; + expr = new_expr; } } + }; + // If the expression is a TupleGetItem, which accesses a field of + // a known tuple, then it can be unwrapped into a direct access of + // that field. + if (auto tuple_get_item = expr.as()) { + Expr tuple = recursively_unwrap_var(tuple_get_item->tuple); if (auto ptr = tuple.as()) { - value = ptr->fields[tuple_get_item->index]; + return ptr->fields[tuple_get_item->index]; + } + } + + // If the expression is a Tuple, and each element is + // `TupleGetItem(earlier_tuple, i)`, then this is just a copy of + // `earlier_tuple`. + auto earlier_tuple = [&]() -> Optional { + auto expr_tuple = expr.as(); + if (!expr_tuple) { + return NullOpt; + } + + if (expr_tuple->fields.empty()) { + return NullOpt; + } + + auto first_element = recursively_unwrap_var(expr_tuple->fields[0]).as(); + if (!first_element) { + return NullOpt; + } + + auto earlier_tuple_size = + Downcast(GetStructInfo(first_element->tuple))->fields.size(); + if (earlier_tuple_size != expr_tuple->fields.size()) { + return NullOpt; } + + Expr earlier_tuple = recursively_unwrap_var(first_element->tuple); + + for (size_t i = 0; i < expr_tuple->fields.size(); i++) { + auto element = recursively_unwrap_var(expr_tuple->fields[i]).as(); + if (!element) { + return NullOpt; + } + if (static_cast(element->index) != i) { + return NullOpt; + } + + auto source_of_element = recursively_unwrap_var(element->tuple); + + if (!earlier_tuple.same_as(source_of_element)) { + return NullOpt; + } + } + + return earlier_tuple; + }(); + if (earlier_tuple) { + return earlier_tuple.value(); + } + + return NullOpt; + } + + void VisitBinding(const Binding& binding) override { + bool has_same_struct_info = [&]() { + if (binding.as()) { + return true; + } else if (auto match_cast = binding.as()) { + return StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(match_cast->value)); + } else { + LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); + } + }(); + + Expr value = GetBoundValue(binding); + + if (auto unwrapped = UnwrapKnownValue(value)) { + value = unwrapped.value(); } if (auto parent = value.as(); parent && has_same_struct_info) { diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index a7ff8cdc3202..1d982b0972ed 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -1294,5 +1294,56 @@ def _get_binding_names(mod): assert after_names == expected_names +def test_trace_tuple_through_round_trip(): + """Canonicalize to the orignal tuple, without unwrap/rewrap.""" + + @I.ir_module + class Before: + @R.function + def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])): + with R.dataflow(): + A = param_tuple[0] + B = param_tuple[1] + C = param_tuple[2] + output = (A, B, C) + R.output(output) + return output + + @I.ir_module + class Expected: + @R.function + def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])): + with R.dataflow(): + A = param_tuple[0] + B = param_tuple[1] + C = param_tuple[2] + R.output() + + return param_tuple + + After = CanonicalizeBindings()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + +def test_trace_partial_tuple_through_round_trip(): + """Canonicalize to the orignal tuple, without unwrap/rewrap.""" + + @I.ir_module + class Before: + @R.function + def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])): + with R.dataflow(): + A = param_tuple[0] + B = param_tuple[1] + output = (A, B) + R.output(output) + return output + + Expected = Before + + After = CanonicalizeBindings()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + if __name__ == "__main__": tvm.testing.main() From 6ca0bea2d89bf11a315332983486437b6a4a90f2 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 28 Aug 2024 19:31:02 -0400 Subject: [PATCH 092/202] [Fix][TIR] LowerThreadAllreduce warp reduction mask (#17307) The warp reduction implemented by "shuffle down" primitive takes a mask denoting the active threads within the warp that participate in this shuffle. Previously we compute the mask, while in practice we find that it results in "CUDA illegal instruction" error on NVIDIA H100 GPU when the mask is set, and the issue is gone if we do not update the mask. Therefore, this PR updates the allreduce lowering to remove the mask update. Confirmed the correctness on the following devices: * NVIDIA H100, * NVIDIA RTX 4090, * AMD Radeon 7900 XTX, * Apple M2 Ultra. --- src/tir/transforms/lower_thread_allreduce.cc | 7 ------- .../test_tir_transform_lower_thread_all_reduce.py | 15 ++++----------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 37d8f67580fe..dde33fa2678d 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -294,10 +294,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); if (reduce_extent <= warp_size_) { - if (group_extent > 1 && reduce_extent < warp_size_) { - mask = mask & - (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index))); - } std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce( values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq); @@ -352,9 +348,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{group_index * n_warps + reduce_index}); } - if (n_warps < warp_size_) { - mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps)); - } std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( values, types, combiner, reduce_index, n_warps, group_index, mask, /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq); diff --git a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py b/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py index d8c9568da90e..18d6339349ff 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py @@ -342,10 +342,7 @@ def expected(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")): t0 = T.decl_buffer([1], "float32", scope="local") A_1 = T.Buffer((256,), data=A.data) red_buf0_1[0] = A_1[threadIdx_y * 8 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), - T.shift_left(T.uint32(255), T.uint32(8) * T.Cast("uint32", threadIdx_y)), - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 4, 32, 32) red_buf0_1[0] = red_buf0_1[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 2, 32, 32) @@ -421,7 +418,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")) T.tvm_storage_sync("shared") if threadIdx_x < 4: red_buf0[0] = red_buf_staging[threadIdx_x] - mask[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15)) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32) @@ -573,9 +570,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): T.tvm_storage_sync("shared") if threadIdx_x < 4: red_buf0[0] = red_buf_staging[threadIdx_y * 4 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4)) - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32) @@ -657,9 +652,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): T.tvm_storage_sync("shared") if threadIdx_x < 16: red_buf0[0] = red_buf_staging[threadIdx_y * 16 + threadIdx_x] - mask[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16)) - ) + mask[0] = T.tvm_warp_activemask() t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 8, 32, 32) red_buf0[0] = red_buf0[0] + t0[0] t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 4, 32, 32) From 2b56ce6c669b6325889af407cd6858a055c17f14 Mon Sep 17 00:00:00 2001 From: Honglin Zhu Date: Thu, 29 Aug 2024 17:58:00 +0800 Subject: [PATCH 093/202] [Relax][Frontend][Onnx] fix expand bug in onnx frontend (#17309) * fix expand bug in onnx frontend * add test expand_with_diff_dim --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 ++ tests/python/relax/test_frontend_onnx.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 85d4402d6640..c3116f9988ce 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1135,6 +1135,8 @@ def _impl_v13(cls, bb, inputs, attr, params): # For some reason, onnx allows target shapes to be smaller than input shapes. # We need to go correct it. data_shape = [dim.value for dim in data.struct_info.shape] + # Dimensions are right alignment. + data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape # Fix small target shapes. for i, s in enumerate(new_shape): if i < len(data_shape) and s < data_shape[i]: diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 05316f2699dd..3ea987973578 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1118,6 +1118,12 @@ def _test_expand(name, data, shape, ref_data): ref_data = np.tile(data, 4) _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data) + in_shape = (3, 1) + shape = (1, 3, 4) + data = np.random.uniform(size=in_shape).astype(np.float32) + ref_data = np.tile(data, (1, 1, 4)) + _test_expand("expand_with_diff_dim", data, shape, ref_data) + # TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed. @pytest.mark.skip("Produces ill-formed IR") From add93d7372cf255b4f1fb094c7d1e0eb8ae25321 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 29 Aug 2024 19:08:36 +0800 Subject: [PATCH 094/202] [Doc] Refactor How-To (#17306) This PR refactors the how-to section and add new tutorials of `end-to-end optimization model` --- docs/conf.py | 2 + docs/dev/how_to/how_to.rst | 2 - docs/how_to/dev/index.rst | 28 ++++ .../dev}/pytest_target_parametrization.rst | 0 .../dev}/setup_rpc_system.rst | 6 +- docs/how_to/index.rst | 22 +-- docs/how_to/legacy_index.rst | 38 +++++ docs/how_to/tutorials/README.txt | 2 + .../tutorials}/cross_compilation_and_rpc.py | 0 docs/how_to/tutorials/e2e_opt_model.py | 139 ++++++++++++++++++ docs/index.rst | 16 +- gallery/tutorial/install.py | 50 ------- gallery/tutorial/introduction.py | 2 - 13 files changed, 221 insertions(+), 86 deletions(-) create mode 100644 docs/how_to/dev/index.rst rename docs/{dev/how_to => how_to/dev}/pytest_target_parametrization.rst (100%) rename docs/{dev/how_to => how_to/dev}/setup_rpc_system.rst (99%) create mode 100644 docs/how_to/legacy_index.rst create mode 100644 docs/how_to/tutorials/README.txt rename {gallery/tutorial => docs/how_to/tutorials}/cross_compilation_and_rpc.py (100%) create mode 100644 docs/how_to/tutorials/e2e_opt_model.py delete mode 100644 gallery/tutorial/install.py diff --git a/docs/conf.py b/docs/conf.py index 1c5c5cb5d602..c933653233b1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -423,6 +423,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): tvm_path.joinpath("vta", "tutorials"), # New tutorial structure under docs folder tvm_path.joinpath("docs", "get_started", "tutorials"), + tvm_path.joinpath("docs", "how_to", "tutorials"), ] gallery_dirs = [ @@ -440,6 +441,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): "topic/vta/tutorials", # New tutorial structure under docs folder "get_started/tutorials/", + "how_to/tutorials/", ] diff --git a/docs/dev/how_to/how_to.rst b/docs/dev/how_to/how_to.rst index 1e1d1236bd51..aa89324fb949 100644 --- a/docs/dev/how_to/how_to.rst +++ b/docs/dev/how_to/how_to.rst @@ -29,5 +29,3 @@ various areas of the TVM stack. relay_add_op relay_add_pass relay_bring_your_own_codegen - pytest_target_parametrization - setup_rpc_system diff --git a/docs/how_to/dev/index.rst b/docs/how_to/dev/index.rst new file mode 100644 index 000000000000..c70832358a41 --- /dev/null +++ b/docs/how_to/dev/index.rst @@ -0,0 +1,28 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Develope Apache TVM +=================== +This section contains a collection of tips about how to work on +various areas of the TVM stack. + +.. toctree:: + :maxdepth: 1 + + pytest_target_parametrization + setup_rpc_system + ../../errors diff --git a/docs/dev/how_to/pytest_target_parametrization.rst b/docs/how_to/dev/pytest_target_parametrization.rst similarity index 100% rename from docs/dev/how_to/pytest_target_parametrization.rst rename to docs/how_to/dev/pytest_target_parametrization.rst diff --git a/docs/dev/how_to/setup_rpc_system.rst b/docs/how_to/dev/setup_rpc_system.rst similarity index 99% rename from docs/dev/how_to/setup_rpc_system.rst rename to docs/how_to/dev/setup_rpc_system.rst index 061aa5b07b9c..0131619b71d2 100644 --- a/docs/dev/how_to/setup_rpc_system.rst +++ b/docs/how_to/dev/setup_rpc_system.rst @@ -76,7 +76,7 @@ In our community, there is multiple RPC server implementations, e.g., ``apps/and RPC server need to be run on device machine, and it usually will depend on xPU driver, the enhanced TVM runtime with xPU support, and other libraries, so please setup the dependent components first, e.g., install the KMD driver, ensure the required dynamic libraries can be found from environment variable ``LD_LIBRARY_PATH``. -If the required compilation environment can be setup on your device machine, i.e., you needn't to do the cross compilation, then just follow the instruction of ``_ to compile the TVM runtime and directly jump to the step :ref:`luanch-rpc-server`. +If the required compilation environment can be setup on your device machine, i.e., you needn't to do the cross compilation, then just follow the instruction of ``_ to compile the TVM runtime and directly jump to the step :ref:`launch-rpc-server`. 1. Cross Compile TVM Runtime ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -134,9 +134,9 @@ Then copy the compress package ``tvm_runtime.tar.gz`` to your concrete device ma $ export PYTHONPATH=`pwd`/python:${PYTHONPATH} -.. _luanch-rpc-server: +.. _launch-rpc-server: -3. Luanch RPC Server +3. Launch RPC Server ^^^^^^^^^^^^^^^^^^^^ The RPC server can be launched on your device machine through the commands like something below, please modify the *RPC_TRACKER_IP*, *RPC_TRACKER_PORT*, *RPC_PROXY_IP*, *RPC_PROXY_PORT*, and *RPC_KEY* according to your concrete environment. diff --git a/docs/how_to/index.rst b/docs/how_to/index.rst index 433d7acee95a..976b2f1bd4ba 100644 --- a/docs/how_to/index.rst +++ b/docs/how_to/index.rst @@ -15,25 +15,9 @@ specific language governing permissions and limitations under the License. -How To Guides -============= - -These user-focused "how to" guides are designed to help you find answers to -specific questions, like "How do I compile a model?" or "How to I optimize a -schedule with tesor expressions?" - .. toctree:: :maxdepth: 1 - compile_models/index - deploy/index - work_with_relay/index - work_with_schedules/index - optimize_operators/index - tune_with_autotvm/index - tune_with_autoscheduler/index - work_with_microtvm/index - extend_tvm/index - profile/index - ../errors - ../faq + tutorials/e2e_opt_model + tutorials/cross_compilation_and_rpc + dev/index diff --git a/docs/how_to/legacy_index.rst b/docs/how_to/legacy_index.rst new file mode 100644 index 000000000000..a98e04c96978 --- /dev/null +++ b/docs/how_to/legacy_index.rst @@ -0,0 +1,38 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +How To Guides +============= + +These user-focused "how to" guides are designed to help you find answers to +specific questions, like "How do I compile a model?" or "How to I optimize a +schedule with tesor expressions?" + +.. toctree:: + :maxdepth: 1 + + compile_models/index + deploy/index + work_with_relay/index + work_with_schedules/index + optimize_operators/index + tune_with_autotvm/index + tune_with_autoscheduler/index + work_with_microtvm/index + extend_tvm/index + profile/index + ../faq diff --git a/docs/how_to/tutorials/README.txt b/docs/how_to/tutorials/README.txt new file mode 100644 index 000000000000..9cec77e7b624 --- /dev/null +++ b/docs/how_to/tutorials/README.txt @@ -0,0 +1,2 @@ +HOW TO +------ diff --git a/gallery/tutorial/cross_compilation_and_rpc.py b/docs/how_to/tutorials/cross_compilation_and_rpc.py similarity index 100% rename from gallery/tutorial/cross_compilation_and_rpc.py rename to docs/how_to/tutorials/cross_compilation_and_rpc.py diff --git a/docs/how_to/tutorials/e2e_opt_model.py b/docs/how_to/tutorials/e2e_opt_model.py new file mode 100644 index 000000000000..a139e75cfe6a --- /dev/null +++ b/docs/how_to/tutorials/e2e_opt_model.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _optimize_model: + +End-to-End Optimize Model +========================= +This tutorial demonstrates how to optimize a machine learning model using Apache TVM. We will +use a pre-trained ResNet-18 model from PyTorch and end-to-end optimize it using TVM's Relax API. +Please note that default end-to-end optimization may not suit complex models. +""" + +###################################################################### +# Preparation +# ----------- +# First, we prepare the model and input information. We use a pre-trained ResNet-18 model from +# PyTorch. + +import os +import sys +import numpy as np +import torch +from torch import fx +from torchvision.models.resnet import ResNet18_Weights, resnet18 + +torch_model = resnet18(weights=ResNet18_Weights.DEFAULT) + +###################################################################### +# Review Overall Flow +# ------------------- +# .. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg +# :align: center +# :width: 80% +# +# The overall flow consists of the following steps: +# +# - **Construct or Import a Model**: Construct a neural network model or import a pre-trained +# model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains +# all the information needed for compilation, including high-level Relax functions for +# computational graph, and low-level TensorIR functions for tensor program. +# - **Perform Composable Optimizations**: Perform a series of optimization transformations, +# such as graph optimizations, tensor program optimizations, and library dispatching. +# - **Build and Universal Deployment**: Build the optimized model to a deployable module to the +# universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators. +# + + +###################################################################### +# Convert the model to IRModule +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further +# optimization. Besides the model, we also need to provide the input shape and data type. + +import tvm +from tvm import relax +from tvm.relax.frontend.torch import from_fx + +torch_model = resnet18(weights=ResNet18_Weights.DEFAULT) + +# Give the input shape and data type +input_info = [((1, 3, 224, 224), "float32")] + +# Convert the model to IRModule +with torch.no_grad(): + torch_fx_model = fx.symbolic_trace(torch_model) + mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True) + +mod, params = relax.frontend.detach_params(mod) +mod.show() + +###################################################################### +# IRModule Optimization +# --------------------- +# Apache TVM Unity provides a flexible way to optimize the IRModule. Everything centered +# around IRModule optimization can be composed with existing pipelines. Note that each +# transformation can be combined as an optimization pipeline via ``tvm.ir.transform.Sequential``. +# +# In this tutorial, we focus on the end-to-end optimization of the model via auto-tuning. We +# leverage MetaSchedule to tune the model and store the tuning logs to the database. We also +# apply the database to the model to get the best performance. +# + +TOTAL_TRIALS = 8000 # Change to 20000 for better performance if needed +target = tvm.target.Target("nvidia/geforce-rtx-3090-ti") # Change to your target device +work_dir = "tuning_logs" + +# Skip running in CI environment +IS_IN_CI = os.getenv("CI", "") == "true" +if IS_IN_CI: + sys.exit(0) + +with target: + mod = tvm.ir.transform.Sequential( + [ + # Convert BatchNorm into a sequence of simpler ops for fusion + relax.transform.DecomposeOpsForInference(), + # Canonicalize the bindings + relax.transform.CanonicalizeBindings(), + # Run default optimization pipeline + relax.get_pipeline("zero"), + # Tune the model and store the log to database + relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS), + # Apply the database + relax.transform.MetaScheduleApplyDatabase(work_dir), + ] + )(mod) + +# Only show the main function +mod["main"].show() + +###################################################################### +# Build and Deploy +# ---------------- +# Finally, we build the optimized model and deploy it to the target device. + +ex = relax.build(mod, target="cuda") +dev = tvm.device("cuda", 0) +vm = relax.VirtualMachine(ex, dev) +# Need to allocate data and params on GPU device +gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev) +gpu_params = [tvm.nd.array(p, dev) for p in params["main"]] +gpu_out = vm["main"](gpu_data, *gpu_params).numpy() + +print(gpu_out.shape) diff --git a/docs/index.rst b/docs/index.rst index 07022cdef7ae..fdfaa56f7454 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -36,22 +36,13 @@ driving its costs down. install/index get_started/tutorials/quick_start get_started/tutorials/ir_module - contribute/index .. toctree:: :maxdepth: 1 - :caption: User Guide + :caption: How To - tutorial/index how_to/index -.. toctree:: - :maxdepth: 1 - :caption: Developer Guide - - dev/tutorial/index - dev/how_to/how_to.rst - .. toctree:: :maxdepth: 1 :caption: API Reference @@ -63,6 +54,10 @@ driving its costs down. :maxdepth: 1 :caption: Legacy + tutorial/index + how_to/legacy_index + dev/tutorial/index + dev/how_to/how_to.rst reference/langref/index arch/index topic/microtvm/index @@ -72,6 +67,7 @@ driving its costs down. :maxdepth: 1 :caption: About + contribute/index reference/publications reference/security diff --git a/gallery/tutorial/install.py b/gallery/tutorial/install.py deleted file mode 100644 index 0eb3ccc94c06..000000000000 --- a/gallery/tutorial/install.py +++ /dev/null @@ -1,50 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Installing TVM -============== -**Authors**: -`Jocelyn Shiue `_, -`Chris Hoge `_ - -Depending on your needs and your working environment, there are a few different -methods for installing TVM. These include: - -* Installing from source -* Installing from third-party binary package. -""" - -################################################################################ -# Installing From Source -# ---------------------- -# Installing from source is the recommended method for installing TVM. It will -# allow you to enable specific features such as GPU support, microcontroller -# support (microTVM), and a debugging runtime, and other features. You will also -# want to install from source if you want to actively contribute to the TVM -# project. The full instructions are on the :ref:`Install TVM From Source -# ` page. - -################################################################################ -# Installing From Binary Packages -# -------------------------------- -# You may install convenient third party binary package distributions to -# quickly try things out. TLCPack is a third party volunteer community that -# builds binary packages from TVM source. It offers a support matrix with -# instructions to install on different platforms, with different features. -# Check out `TLCPack `_ to learn more. Note that the -# third party binary packages could contain additional licensing terms for -# the hardware drivers that are bundled with it. diff --git a/gallery/tutorial/introduction.py b/gallery/tutorial/introduction.py index 8d1f0e2699b2..4b94b23cf944 100644 --- a/gallery/tutorial/introduction.py +++ b/gallery/tutorial/introduction.py @@ -35,13 +35,11 @@ -------- #. :doc:`Introduction ` -#. :doc:`Installing TVM ` #. :doc:`Compiling and Optimizing a Model with the Command Line Interface ` #. :doc:`Compiling and Optimizing a Model with the Python Interface ` #. :doc:`Working with Operators Using Tensor Expression ` #. :doc:`Optimizing Operators with Templates and AutoTVM ` #. :doc:`Optimizing Operators with Template-free AutoScheduler ` -#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) ` #. :doc:`Compiling Deep Learning Models for GPUs ` """ From 98de9ba8418ec70ed7da59b737c93bd1b9ab611a Mon Sep 17 00:00:00 2001 From: Yu Xuanchi Date: Thu, 29 Aug 2024 19:11:59 +0800 Subject: [PATCH 095/202] [TVM4J][BugFix] Fix unhandled return type in JNI (#17308) --- jvm/native/src/main/native/jni_helper_func.h | 1 + 1 file changed, 1 insertion(+) diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index 82165e9e04b1..d60a1a4230b7 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -188,6 +188,7 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { switch (tcode) { case kDLUInt: case kDLInt: + case kTVMArgBool: return newTVMValueLong(env, static_cast(value.v_int64)); case kDLFloat: return newTVMValueDouble(env, static_cast(value.v_float64)); From 40b6c14bba2ae31d371644b33e261e4cbaaa5b54 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 1 Sep 2024 14:00:15 -0700 Subject: [PATCH 096/202] [Disco] Add NVSHMEM support (#17317) This PR adds the supports of NVSHMEM. --- CMakeLists.txt | 23 +++++ cmake/modules/LibInfo.cmake | 1 + cmake/utils/FindNVSHMEM.cmake | 52 +++++++++++ src/runtime/contrib/nvshmem/nvshmem.cc | 66 ++++++++++++++ src/support/libinfo.cc | 5 ++ tests/python/disco/test_nvshmem.py | 114 +++++++++++++++++++++++++ 6 files changed, 261 insertions(+) create mode 100644 cmake/utils/FindNVSHMEM.cmake create mode 100644 src/runtime/contrib/nvshmem/nvshmem.cc create mode 100644 tests/python/disco/test_nvshmem.py diff --git a/CMakeLists.txt b/CMakeLists.txt index aa2a385683d7..38dd59b9c906 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,7 @@ include(cmake/utils/FindLLVM.cmake) include(cmake/utils/FindROCM.cmake) include(cmake/utils/FindRCCL.cmake) include(cmake/utils/FindEthosN.cmake) +include(cmake/utils/FindNVSHMEM.cmake) if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake) include(${CMAKE_BINARY_DIR}/config.cmake) @@ -133,6 +134,7 @@ tvm_option(USE_UMA "Build with UMA support" OFF) tvm_option(USE_VERILATOR "Build with Verilator support" OFF) tvm_option(USE_MSC "Enable Multi-System Compiler" OFF) tvm_option(USE_MRVL "Build with MRVL TVM support" OFF) +tvm_option(USE_NVSHMEM "Build with NVSHMEM support" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -472,6 +474,16 @@ if(USE_CUDA AND USE_NCCL) list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC}) endif() +if (USE_CUDA AND USE_NVSHMEM) + message(STATUS "Build with NVSHMEM...") + find_nvshmem(${USE_NVSHMEM}) + if (NOT NVSHMEM_FOUND) + message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM}) + endif() + tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS}) +endif() + if(USE_ROCM AND USE_RCCL) message(STATUS "Build with RCCL...") find_rccl(${USE_RCCL}) @@ -957,6 +969,17 @@ if(USE_CUDA AND USE_NCCL) target_link_libraries(tvm_runtime PRIVATE nccl ${LIBRT}) endif() + +if (USE_CUDA AND USE_NVSHMEM) + include_directories(SYSTEM ${USE_NVSHMEM}/include) + find_library(NVSHMEM_HOST nvshmem_host ${NVSHMEM_LIB_DIR}) + find_library(NVSHMEM_DEVICE nvshmem_device ${NVSHMEM_LIB_DIR}) + target_link_libraries(tvm PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE}) + target_link_libraries(tvm_runtime PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE}) + set_target_properties(tvm PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + set_target_properties(tvm_runtime PROPERTIES CUDA_SEPARABLE_COMPILATION ON) +endif() + if(USE_ROCM AND USE_RCCL) target_link_libraries(tvm PRIVATE rccl) target_link_libraries(tvm_runtime PRIVATE rccl) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index da9bc3e1c9d3..a2b51bb33195 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -143,6 +143,7 @@ function(add_lib_info src_file) TVM_INFO_USE_VERILATOR="${USE_VERILATOR}" TVM_INFO_USE_MSC="${USE_MSC}" TVM_INFO_USE_CCACHE="${USE_CCACHE}" + TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}" TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}" ) diff --git a/cmake/utils/FindNVSHMEM.cmake b/cmake/utils/FindNVSHMEM.cmake new file mode 100644 index 000000000000..1a833332a289 --- /dev/null +++ b/cmake/utils/FindNVSHMEM.cmake @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +####################################################### +# Enhanced version of find NVSHMEM. +# +# Usage: +# find_nvshmem(${USE_NVSHMEM}) +# +# - When USE_NVSHMEM=ON, use auto search +# - When USE_NVSHMEM=/path/to/installed/nvshmem, use the installed nvshmem path. +# Can be useful when nvshmem is installed at specified location. +# +# Provide variables: +# +# - NVSHMEM_FOUND +# - NVSHMEM_INCLUDE_DIR +# - NVSHMEM_LIB_DIR +# + +macro(find_nvshmem use_nvshmem) + set(__use_nvshmem ${use_nvshmem}) + if(IS_DIRECTORY ${__use_nvshmem}) + set(__nvshmem_path ${__use_nvshmem}) + message(STATUS "Custom NVSHMEM PATH=" ${__use_nvshmem}) + elseif(IS_DIRECTORY $ENV{NVSHMEM_HOME}) + set(__nvshmem_path $ENV{NVSHMEM_HOME}) + else() + set(__nvshmem_path "") + endif() + + find_package(NVSHMEM HINTS ${__nvshmem_path}/lib/cmake/nvshmem/) + + if(NVSHMEM_FOUND) + message(STATUS "NVSHMEM_INCLUDE_DIR=" ${NVSHMEM_INCLUDE_DIR}) + message(STATUS "NVSHMEM_LIB_DIR=" ${NVSHMEM_LIB_DIR}) + endif(NVSHMEM_FOUND) +endmacro(find_nvshmem) diff --git a/src/runtime/contrib/nvshmem/nvshmem.cc b/src/runtime/contrib/nvshmem/nvshmem.cc new file mode 100644 index 000000000000..985ba5510762 --- /dev/null +++ b/src/runtime/contrib/nvshmem/nvshmem.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +namespace tvm { +namespace runtime { + +ShapeTuple InitNVSHMEMUID() { + nvshmemx_uniqueid_t uid; + nvshmemx_get_uniqueid(&uid); + std::vector uid_64; + uid_64.push_back(static_cast(uid.version)); + for (int i = 0; i < UNIQUEID_PADDING; ++i) { + uid_64.push_back(static_cast(uid.internal[i])); + } + return ShapeTuple(uid_64); +} + +void InitNVSHMEM(ShapeTuple uid_64, int num_workers) { + DiscoWorker* worker = DiscoWorker::ThreadLocal(); + ICHECK(worker != nullptr); + CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1) + << "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << ", but got " + << uid_64.size() << "."; + + nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER; + + nvshmemx_uniqueid_t uid; + uid.version = static_cast(uid_64[0]); + for (int i = 0; i < UNIQUEID_PADDING; ++i) { + uid.internal[i] = static_cast(uid_64[i + 1]); + } + nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr); + nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr); + LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " " + << ", npes=" << nvshmem_n_pes(); +} + +TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID); + +TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM); + +} // namespace runtime +} // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 984a2f3323ad..73800338b143 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -275,6 +275,10 @@ #define TVM_INFO_USE_CCACHE "NOT-FOUND" #endif +#ifndef TVM_INFO_USE_NVSHMEM +#define TVM_INFO_USE_NVSHMEM "NOT-FOUND" +#endif + namespace tvm { /*! @@ -387,6 +391,7 @@ TVM_DLL Map GetLibInfo() { {"USE_VERILATOR", TVM_INFO_USE_VERILATOR}, {"USE_MSC", TVM_INFO_USE_MSC}, {"USE_CCACHE", TVM_INFO_USE_CCACHE}, + {"USE_NVSHMEM", TVM_INFO_USE_NVSHMEM}, {"BACKTRACE_ON_SEGFAULT", TVM_INFO_BACKTRACE_ON_SEGFAULT}, }; return result; diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py new file mode 100644 index 000000000000..0b16fe93612f --- /dev/null +++ b/tests/python/disco/test_nvshmem.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Basic tests for a Disco nvshmem support""" +# pylint: disable=missing-docstring +import tempfile + +import numpy as np +import pytest +import subprocess +import threading +import sys + +import tvm +import tvm.testing +from tvm.runtime import ShapeTuple +from tvm.runtime import disco as di +from tvm.exec import disco_worker as _ # pylint: disable=unused-import + +_SOCKET_SESSION_TESTER = None + + +def get_free_port(): + import socket + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + s.close() + return port + + +class SocketSessionTester: + def __init__(self, num_workers): + num_nodes = 2 + num_groups = 1 + assert num_workers % num_nodes == 0 + num_workers_per_node = num_workers // num_nodes + server_host = "localhost" + server_port = get_free_port() + self.sess = None + + def start_server(): + self.sess = di.SocketSession( + num_nodes, num_workers_per_node, num_groups, server_host, server_port + ) + + thread = threading.Thread(target=start_server) + thread.start() + + cmd = "tvm.exec.disco_remote_socket_session" + self.remote_nodes = [] + for _ in range(num_nodes - 1): + self.remote_nodes.append( + subprocess.Popen( + [ + "python3", + "-m", + cmd, + server_host, + str(server_port), + str(num_workers_per_node), + ], + stdout=sys.stdout, + stderr=sys.stderr, + ) + ) + + thread.join() + + def __del__(self): + for node in self.remote_nodes: + node.kill() + if self.sess is not None: + self.sess.shutdown() + del self.sess + + +def create_socket_session(num_workers): + global _SOCKET_SESSION_TESTER + if _SOCKET_SESSION_TESTER is not None: + del _SOCKET_SESSION_TESTER + _SOCKET_SESSION_TESTER = SocketSessionTester(num_workers) + assert _SOCKET_SESSION_TESTER.sess is not None + return _SOCKET_SESSION_TESTER.sess + + +@pytest.mark.parametrize("num_workers", [2, 4]) +def test_nvshmem_init(num_workers): + if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None: + return + sess = create_socket_session(num_workers=num_workers) + f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid") + uid = f_init_nvshmem_uid() + init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem") + init_dfunc(uid, num_workers) + sess.sync_worker_0() + + +if __name__ == "__main__": + tvm.testing.main() From 3262f19e6f7a6f58dc643e2585f196ef91c6bdab Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 2 Sep 2024 14:06:28 +0800 Subject: [PATCH 097/202] [Doc] Fix doc build error in e2e_opt_model.py (#17319) The `sys.exit` may stop the whole sphinx build process, but not the single script execution. --- docs/how_to/tutorials/e2e_opt_model.py | 63 +++++++++++++------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/docs/how_to/tutorials/e2e_opt_model.py b/docs/how_to/tutorials/e2e_opt_model.py index a139e75cfe6a..0053d309d5a9 100644 --- a/docs/how_to/tutorials/e2e_opt_model.py +++ b/docs/how_to/tutorials/e2e_opt_model.py @@ -32,7 +32,6 @@ # PyTorch. import os -import sys import numpy as np import torch from torch import fx @@ -101,39 +100,39 @@ # Skip running in CI environment IS_IN_CI = os.getenv("CI", "") == "true" -if IS_IN_CI: - sys.exit(0) - -with target: - mod = tvm.ir.transform.Sequential( - [ - # Convert BatchNorm into a sequence of simpler ops for fusion - relax.transform.DecomposeOpsForInference(), - # Canonicalize the bindings - relax.transform.CanonicalizeBindings(), - # Run default optimization pipeline - relax.get_pipeline("zero"), - # Tune the model and store the log to database - relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS), - # Apply the database - relax.transform.MetaScheduleApplyDatabase(work_dir), - ] - )(mod) - -# Only show the main function -mod["main"].show() +if not IS_IN_CI: + with target: + mod = tvm.ir.transform.Sequential( + [ + # Convert BatchNorm into a sequence of simpler ops for fusion + relax.transform.DecomposeOpsForInference(), + # Canonicalize the bindings + relax.transform.CanonicalizeBindings(), + # Run default optimization pipeline + relax.get_pipeline("zero"), + # Tune the model and store the log to database + relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS), + # Apply the database + relax.transform.MetaScheduleApplyDatabase(work_dir), + ] + )(mod) + + # Only show the main function + mod["main"].show() ###################################################################### # Build and Deploy # ---------------- # Finally, we build the optimized model and deploy it to the target device. - -ex = relax.build(mod, target="cuda") -dev = tvm.device("cuda", 0) -vm = relax.VirtualMachine(ex, dev) -# Need to allocate data and params on GPU device -gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev) -gpu_params = [tvm.nd.array(p, dev) for p in params["main"]] -gpu_out = vm["main"](gpu_data, *gpu_params).numpy() - -print(gpu_out.shape) +# We skip this step in the CI environment. + +if not IS_IN_CI: + ex = relax.build(mod, target="cuda") + dev = tvm.device("cuda", 0) + vm = relax.VirtualMachine(ex, dev) + # Need to allocate data and params on GPU device + gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev) + gpu_params = [tvm.nd.array(p, dev) for p in params["main"]] + gpu_out = vm["main"](gpu_data, *gpu_params).numpy() + + print(gpu_out.shape) From cd3448603dffea2340e406dd7751a37b0440d81f Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 2 Sep 2024 14:06:37 +0800 Subject: [PATCH 098/202] [Doc] Customize Optimization (#17320) [Doc] Customization Optimization --- docs/how_to/index.rst | 1 + docs/how_to/tutorials/customize_opt.py | 225 +++++++++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 docs/how_to/tutorials/customize_opt.py diff --git a/docs/how_to/index.rst b/docs/how_to/index.rst index 976b2f1bd4ba..c5b9d703f032 100644 --- a/docs/how_to/index.rst +++ b/docs/how_to/index.rst @@ -19,5 +19,6 @@ :maxdepth: 1 tutorials/e2e_opt_model + tutorials/customize_opt tutorials/cross_compilation_and_rpc dev/index diff --git a/docs/how_to/tutorials/customize_opt.py b/docs/how_to/tutorials/customize_opt.py new file mode 100644 index 000000000000..5806d6ce5da1 --- /dev/null +++ b/docs/how_to/tutorials/customize_opt.py @@ -0,0 +1,225 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _customize_opt: + +Customize Optimization +====================== +One main design goal of Apache TVM is to enable easy customization of the optimization pipeline +for both research or development purposes and iterate the engineering optimizations. In this +tutorial we will + +.. contents:: Table of Contents + :local: + :depth: 1 +""" + +###################################################################### +# Review Overall Flow +# ------------------- +# .. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg +# :align: center +# :width: 80% +# +# The overall flow consists of the following steps: +# +# - **Construct or Import a Model**: Construct a neural network model or import a pre-trained +# model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains +# all the information needed for compilation, including high-level Relax functions for +# computational graph, and low-level TensorIR functions for tensor program. +# - **Perform Composable Optimizations**: Perform a series of optimization transformations, +# such as graph optimizations, tensor program optimizations, and library dispatching. +# - **Build and Universal Deployment**: Build the optimized model to a deployable module to the +# universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators. +# + +import os +import tempfile +import numpy as np +import tvm +from tvm import IRModule, relax +from tvm.relax.frontend import nn + +###################################################################### +# Composable IRModule Optimization +# -------------------------------- +# Apache TVM Unity provides a flexible way to optimize the IRModule. Everything centered +# around IRModule optimization can be composed with existing pipelines. Note that each optimization +# can focus on **part of the computation graph**, enabling partial lowering or partial optimization. +# +# In this tutorial, we will demonstrate how to optimize a model with Apache TVM Unity. + +###################################################################### +# Prepare a Relax Module +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# We first prepare a Relax module. The module can be imported from other frameworks, constructed +# with NN module frontend or TVMScript. Here we use a simple neural network model as an example. + + +class RelaxModel(nn.Module): + def __init__(self): + super(RelaxModel, self).__init__() + self.fc1 = nn.Linear(784, 256) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(256, 10, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + return x + + +input_shape = (1, 784) +mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}}) +mod.show() + +###################################################################### +# Library Dispatch +# ~~~~~~~~~~~~~~~~ +# We would like to quickly try out a variant of library optimization for certain platforms +# (e.g., GPU). We can write a certain dispatching pass for the specific platform and +# operator. Here we demonstrate how to dispatch the CUBLAS library for certain patterns. +# +# .. note:: +# This tutorial only demonstrates a single operator dispatching for CUBLAS, highlighting +# the flexibility of the optimization pipeline. In real-world cases, we can import multiple +# patterns and dispatch them to different kernels. + + +# Import cublas pattern +import tvm.relax.backend.contrib.cublas as _cublas + + +# Define a new pass for CUBLAS dispatch +@tvm.transform.module_pass(opt_level=0, name="CublasDispatch") +class CublasDispatch: + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + # Check if CUBLAS is enabled + if not tvm.get_global_func("relax.ext.cublas", True): + raise Exception("CUBLAS is not enabled.") + + # Get interested patterns + patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")] + # Note in real-world cases, we usually get all patterns + # patterns = relax.backend.get_patterns_with_prefix("cublas") + + # Fuse ops by patterns and then run codegen + mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod) + mod = relax.transform.RunCodegen()(mod) + return mod + + +mod = CublasDispatch()(mod) +mod.show() + +###################################################################### +# After the dispatching pass, we can see that the first ``nn.Linear`` and ``nn.ReLU`` are fused +# and rewritten to a ``call_dps_packed`` function which call the CUBLAS library. Notably, the +# other part is not changed, which means we can selectively dispatch the optimization for +# certain computation. + +###################################################################### +# Auto Tuning +# ~~~~~~~~~~~ +# Continuing from the previous example, we can further optimize the model with auto-tuning for +# the **rest part of the computation**. Here we demonstrate how to use the meta-schedule to auto-tune +# the model. +# +# We can use ``MetaScheduleTuneTIR`` pass to simply tuning the model, while ``MetaScheduleApplyDatabase`` +# pass to apply the best configuration to the model. The tuning process will generate search space, +# tune the model and the following steps will apply the best configuration to the model. Before +# running the passes, we need to lowering relax operator into TensorIR functions via ``LegalizeOps`` +# +# .. note:: +# +# To save CI time and avoid flakiness, we skip the tuning process in CI environment. +# + +device = tvm.cuda(0) +target = tvm.target.Target.from_device(device) +if os.getenv("CI", "") != "true": + trials = 2000 + with target, tempfile.TemporaryDirectory() as tmp_dir: + mod = tvm.ir.transform.Sequential( + [ + relax.get_pipeline("zero"), + relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials), + relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir), + ] + )(mod) + + mod.show() + +###################################################################### +# DLight Rules +# ~~~~~~~~~~~~ +# DLight rules are a set of default rules for scheduling and optimization the kernel. +# DLight rules are designed for fast compilation and **fair** performance. In some cases, +# e.g. language model, DLight provides excellent performance, while for generic models, +# it achieves a balance between performance and compilation time. + +from tvm import dlight as dl + +# Apply DLight rules +with target: + mod = tvm.ir.transform.Sequential( + [ + relax.get_pipeline("zero"), + dl.ApplyDefaultSchedule( # pylint: disable=not-callable + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + ), + ] + )(mod) + +mod.show() + +###################################################################### +# .. note:: +# +# This tutorial focuses on the demonstration of the optimization pipeline, instead of +# pushing the performance to the limit. The current optimization may not be the best. + + +###################################################################### +# Deploy the Optimized Model +# -------------------------- +# We can build and deploy the optimized model to the TVM runtime. + +ex = relax.build(mod, target="cuda") +dev = tvm.device("cuda", 0) +vm = relax.VirtualMachine(ex, dev) +# Need to allocate data and params on GPU device +data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev) +gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params] +gpu_out = vm["forward"](data, *gpu_params).numpy() +print(gpu_out) + + +###################################################################### +# Summary +# ------- +# This tutorial demonstrates how to customize the optimization pipeline for ML models in Apache TVM. +# We can easily compose the optimization passes and customize the optimization for different parts +# of the computation graph. The flexibility of the optimization pipeline enables us to quickly +# iterate the optimization and improve the performance of the model. +# From 35e74cc4c9c8dec658217ffeea85f2ba25e35a35 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 3 Sep 2024 01:06:43 +0900 Subject: [PATCH 099/202] [Fix] Remove `tvm.` prefix from image name when `./docker/build.sh` (#17324) remove `tvm.` prefix --- docker/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/README.md b/docker/README.md index c311e86d190a..acebf923b4c0 100644 --- a/docker/README.md +++ b/docker/README.md @@ -110,7 +110,7 @@ tasks. - lint the python codes ```bash - ./docker/build.sh tvm.ci_lint make pylint + ./docker/build.sh ci_lint make pylint ``` - build codes with CUDA support From b06df8464ebd7e785a6dafc440231b0e06c90407 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 3 Sep 2024 08:15:26 -0500 Subject: [PATCH 100/202] [Relax][Transform] Compose preproc functions in LiftTransformParams (#17314) The `LiftTransformParams` pass produces additional functions, either named `$FOO_transform_params` when generating one transformation function per inference function, or `transform_params` when generating a single shared transformation function. Prior to this commit, if the `IRModule` already contained a function with that name, an error would be raised. After this commit, the `LiftTransformParams` pass will instead check for existing functions, and compose the previous transformation function with the newly-lifted transformation. This allows `LiftTransformParams` to be used alongside a hand-written parameter transformation. Closes https://github.com/apache/tvm/issues/17200 --- src/relax/transform/lift_transform_params.cc | 39 ++++-- src/relax/transform/utils.cc | 51 +++++++ src/relax/transform/utils.h | 14 ++ .../test_transform_lift_transform_params.py | 129 ++++++++++++------ 4 files changed, 184 insertions(+), 49 deletions(-) diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 937cb8702952..76df48430592 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -119,7 +119,10 @@ struct BaseCollectInfo { Function func(params, body, GetStructInfo(tuple_var)); func = WithAttr(func, attr::kNumInput, Integer(0)); func = CopyWithNewVars(func); + func = BundleModelParams(func); func = Downcast(CanonicalizeBindings(func)); + func = Downcast(RemoveAllUnused(func)); + return func; } }; @@ -725,11 +728,12 @@ std::vector> GetTargetFunctions( target_functions.push_back({gvar.value(), func.value()}); } } else { - // Get all the functions that have the `num_input` attribute. + // Get all the functions that have the `num_input` attribute, and + // are not already the result of `LiftTransformParams`. for (const auto& [gvar, func] : mod->functions) { if (func->IsInstance()) { auto opt_num_input = func->GetAttr(attr::kNumInput); - if (opt_num_input) { + if (opt_num_input && !ends_with(gvar->name_hint, "transform_params")) { target_functions.emplace_back(gvar, Downcast(func)); } } @@ -748,7 +752,6 @@ namespace transform { Pass PartitionTransformParams(Variant> shared_transform) { auto pass_func = [=](IRModule mod, PassContext pc) { - IRModule updates; std::optional global_collect_info; CHECK(shared_transform.defined()) << "shared_transform is not defined"; @@ -772,24 +775,41 @@ Pass PartitionTransformParams(Variant> shared_transform) { local_collect_info[gvar] = info; } + IRModule updated_runtime_functions; + for (const auto& [gvar, info] : local_collect_info) { auto new_runtime_func = info.MakeRuntimeFunction(); - updates->Add(gvar, new_runtime_func); + updated_runtime_functions->Add(gvar, new_runtime_func); } + Map lifted_transform_functions; if (global_collect_info.has_value()) { auto global_transform = global_collect_info.value().MakeCompileTimeFunc(); - updates->Add(GlobalVar("transform_params"), global_transform); + lifted_transform_functions.Set("transform_params", global_transform); } else { for (const auto& [gvar, info] : local_collect_info) { // transform_params is emitted for each function if global lifting is not enabled - updates->Add(GlobalVar(gvar->name_hint + "_transform_params"), - info.MakeCompileTimeFunction()); + lifted_transform_functions.Set(gvar->name_hint + "_transform_params", + info.MakeCompileTimeFunction()); } } - if (updates->functions.size()) { - mod.CopyOnWrite()->Update(updates); + if (updated_runtime_functions->functions.size() || lifted_transform_functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + write_ptr->Update(updated_runtime_functions); + + for (auto [name, transform] : lifted_transform_functions) { + if (auto opt = write_ptr->global_var_map_.Get(name)) { + auto old_gvar = opt.value(); + auto old_transform = Downcast(write_ptr->Lookup(old_gvar)); + write_ptr->Remove(old_gvar); + + transform = ComposeFunctions(old_transform, transform); + } + GlobalVar new_gvar(name); + UpdateStructInfo(new_gvar, GetStructInfo(transform)); + write_ptr->Add(new_gvar, transform); + } } return mod; @@ -817,7 +837,6 @@ Pass LiftTransformParams(Variant> shared_transform) { std::string func_name = gvar->name_hint; if (ends_with(func_name, "transform_params")) { func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); - func = BundleModelParams(func); if (pc->GetConfig(kLiftTransformConsumeParams).value_or(Bool(false))) { func = Downcast(ConsumeBundledParams()(func)); } diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc index c0fde3bd4cb9..19e93bbc0c0e 100644 --- a/src/relax/transform/utils.cc +++ b/src/relax/transform/utils.cc @@ -19,6 +19,8 @@ #include "utils.h" +#include + namespace tvm { namespace relax { @@ -41,5 +43,54 @@ bool IsNestedTensor(const StructInfo& sinfo) { bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetStructInfo(expr)); } +Function ComposeFunctions(Function func_a, Function func_b) { + Array bindings; + + Var func_a_output("func_a_output", func_a->ret_struct_info); + + bindings.push_back(VarBinding(func_a_output, func_a->body)); + + auto func_a_outputs = [&]() -> Array { + if (auto func_a_output_tuple = func_a->ret_struct_info.as()) { + Array outputs; + for (size_t i = 0; i < func_a_output_tuple->fields.size(); i++) { + outputs.push_back(TupleGetItem(func_a_output, i)); + } + return outputs; + } else { + return {func_a_output}; + } + }(); + + if (func_b->params.size() == 1 && func_b->params[0]->struct_info_.as()) { + // Special case where the output of the first function is a tuple + // that should be provided as-is to the second function, and + // should not be unpacked into individual elements. + auto param = func_b->params[0]; + bindings.push_back(MatchCast(param, func_a_output, GetStructInfo(param))); + } else { + CHECK_EQ(func_a_outputs.size(), func_b->params.size()) + << "ValueError: " + << "Cannot compose functions together. " + << "First function produces " << func_a_outputs.size() << " values, " + << "but second function expects " << func_b->params.size() << " parameters as input"; + for (size_t i = 0; i < func_a_outputs.size(); i++) { + auto param = func_b->params[i]; + bindings.push_back(MatchCast(param, func_a_outputs[i], GetStructInfo(param))); + } + } + + auto new_body = SeqExpr({BindingBlock(bindings)}, func_b->body); + + auto new_function = Function(func_a->params, new_body, func_b->ret_struct_info, + func_a->is_pure && func_b->is_pure, func_a->attrs); + + new_function = CopyWithNewVars(new_function); + new_function = Downcast(CanonicalizeBindings(new_function)); + new_function = Downcast(RemoveAllUnused(new_function)); + + return new_function; +} + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 932dca30a110..55e355b4bac2 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -437,6 +437,20 @@ Expr CanonicalizeBindings(Expr expr); */ Function BundleModelParams(const Function& func, Optional param_tuple_name = NullOpt); +/*! \brief Compose two functions + * + * Given two functions `func_a` and `func_b`, produce `func_c` such + * that `func_c(x)` is equivalent to `func_b(func_a(x))`. + * + * If the output if `func_a` is not usable as the input of `func_b`, + * an error will be raised. + * + * \param func_a The first function to be composed. + * \param func_b The second function to be composed. + * \return The composed function + */ +TVM_DLL Function ComposeFunctions(Function func_a, Function func_b); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index 508664f1ef54..90f2050f7898 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -112,7 +112,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -185,7 +185,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -290,18 +290,15 @@ def main( @R.function def main_transform_params( - params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) + params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") ): R.func_attr({"num_input": 0}) with R.dataflow(): - lv = params[0] - lv0 = (lv,) - lv1 = (lv0,) - lv2 = params[0] - lv3 = params[0] - gv = (lv2, lv3) + l3 = params[0] + w1 = params[0] + gv = (w1, l3) R.output(gv) return gv @@ -340,24 +337,14 @@ def main_transform_params( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((), dtype="bool"), - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((), dtype="bool"), ): R.func_attr({"num_input": 0}) - with R.dataflow(): - lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] - lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] - lv2: R.Tensor((), dtype="bool") = params[2] - gv: R.Tuple( - R.Tensor((16, 16, 3, 3), dtype="float32"), - R.Tensor((16, 16, 3, 3), dtype="float32"), - R.Tensor((), dtype="bool"), - ) = (lv, lv1, lv2) - R.output(gv) - return gv + return params @R.function def main( @@ -434,7 +421,7 @@ def func1( @R.function def func1_transform_params( - params: R.Tuple(R.Tensor((256, 256), dtype="float32")) + params: R.Tuple(R.Tensor((256, 256), dtype="float32")), ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -457,7 +444,7 @@ def func2( @R.function def func2_transform_params( - params: R.Tuple(R.Tensor((128, 256), dtype="float32")) + params: R.Tuple(R.Tensor((128, 256), dtype="float32")), ) -> R.Tuple(R.Tensor((256, 128), dtype="float32")): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -531,7 +518,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -769,7 +756,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -884,7 +871,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -979,7 +966,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1103,7 +1090,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1226,7 +1213,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1322,7 +1309,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1395,7 +1382,7 @@ def func1( @R.function def func1_transform_params( - params: R.Tuple(R.Tensor((256, 256), dtype="float32")) + params: R.Tuple(R.Tensor((256, 256), dtype="float32")), ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1426,9 +1413,6 @@ class Expected: @R.function def main_transform_params(params: R.Tuple) -> R.Tuple: R.func_attr({"num_input": 0}) - with R.dataflow(): - gv: R.Tuple = R.tuple() - R.output() # All instance of the empty tuple are normalized to be # in-line. return R.tuple() @@ -1492,9 +1476,6 @@ def zeros(var_T_full: T.handle): @R.function def main_transform_params(params: R.Tuple) -> R.Tuple: R.func_attr({"num_input": 0}) - with R.dataflow(): - gv: R.Tuple = R.tuple() - R.output() return R.tuple() @R.function @@ -1579,7 +1560,7 @@ def main( @R.function def main_transform_params( - params: R.Tuple(R.Tensor([16, 16], "int32"), R.Shape(["slice_index"])) + params: R.Tuple(R.Tensor([16, 16], "int32"), R.Shape(["slice_index"])), ): R.func_attr({"num_input": 0}) slice_index = T.int64() @@ -1643,7 +1624,7 @@ def main_transform_params( params: R.Tuple( R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3, 3), dtype="float32"), - ) + ), ) -> R.Tuple( R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3, 3), dtype="float32") ): @@ -1821,5 +1802,75 @@ def main_transform_params(params: R.Tuple([R.Tensor([16], "int32")])): tvm.ir.assert_structural_equal(after, Expected) +@pytest.mark.parametrize("shared_transform", [True, False]) +def test_lift_transform_is_idempotent(shared_transform): + """Multiple applicates of LiftTransformParams are allowed""" + + @I.ir_module + class Module: + @R.function + def main( + state: R.Tensor(["batch_size", 4096], "float16"), + base_weights: R.Tensor([4096, 4096], "float16"), + lora_A: R.Tensor([4096, "lora_rank"], "float16"), + lora_B: R.Tensor(["lora_rank", 4096], "float16"), + ): + R.func_attr({"num_input": 1}) + folded_weights = base_weights + R.matmul(lora_A, lora_B) + output = R.matmul(state, folded_weights) + return output + + transform = relax.transform.LiftTransformParams(shared_transform=shared_transform) + + AfterOneRound = transform(Module) + assert len(AfterOneRound.functions) == 2 + + AfterTwoRounds = transform(AfterOneRound) + assert len(AfterTwoRounds.functions) == 2 + + tvm.ir.assert_structural_equal(AfterOneRound, AfterTwoRounds) + + +def test_lift_transform_when_one_already_exists(): + """If the module already contains `transform_params`, the + functions are composed together""" + + @I.ir_module + class Module: + @R.function + def main( + state: R.Tensor(["batch_size", 4096], "float16"), + base_weights: R.Tensor([4096, 4096], "float16"), + lora_A: R.Tensor([4096, "lora_rank"], "float16"), + lora_B: R.Tensor(["lora_rank", 4096], "float16"), + ): + R.func_attr({"num_input": 1}) + folded_weights = base_weights + R.matmul(lora_A, lora_B) + output = R.matmul(state, folded_weights) + return output + + @R.function + def main_transform_params( + model_params: R.Tuple( + R.Tensor([4096, 4096], "float16"), + R.Tensor([4096, "lora_rank"], "float16"), + R.Tensor(["lora_rank", 4096], "float16"), + ), + ): + R.func_attr({"num_input": 0}) + return model_params + + transform = relax.transform.LiftTransformParams(shared_transform=False) + after_lift_with_previous_identity_function = transform(Module) + + del Module["main_transform_params"] + after_lift_without_previous_identity_function = transform(Module) + + tvm.ir.assert_structural_equal( + after_lift_without_previous_identity_function, + after_lift_with_previous_identity_function, + ) + + if __name__ == "__main__": tvm.testing.main() From 42bffc31ff2aa14b18275f70a3d658156dbed2a2 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Tue, 3 Sep 2024 22:51:42 +0800 Subject: [PATCH 101/202] [Target] Refine equality check on TargetKind instances (#17321) refine target kind identity Co-authored-by: wrongtest --- src/target/target_kind.cc | 15 ++++++++++++++- tests/python/target/test_target_target.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index fced74c3a559..979b755af846 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -35,7 +35,20 @@ namespace tvm { -TVM_REGISTER_NODE_TYPE(TargetKindNode); +// helper to get internal dev function in objectref. +struct TargetKind2ObjectPtr : public ObjectRef { + static ObjectPtr Get(const TargetKind& kind) { return GetDataPtr(kind); } +}; + +TVM_REGISTER_NODE_TYPE(TargetKindNode) + .set_creator([](const std::string& name) { + auto kind = TargetKind::Get(name); + ICHECK(kind.defined()) << "Cannot find target kind \'" << name << '\''; + return TargetKind2ObjectPtr::Get(kind.value()); + }) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py index e977ef10aae0..1a52a46da1fc 100644 --- a/tests/python/target/test_target_target.py +++ b/tests/python/target/test_target_target.py @@ -559,5 +559,21 @@ def test_target_from_device_opencl(input_device): assert target.thread_warp_size == dev.warp_size +def test_module_dict_from_deserialized_targets(): + target = Target("llvm") + + from tvm.script import tir as T + + @T.prim_func + def func(): + T.evaluate(0) + + func = func.with_attr("Target", target) + target2 = tvm.ir.load_json(tvm.ir.save_json(target)) + mod = tvm.IRModule({"main": func}) + lib = tvm.build({target2: mod}, target_host=target) + lib["func"]() + + if __name__ == "__main__": tvm.testing.main() From 0e9c68303543e9b7e7a0146553aa0e81f63828f4 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 4 Sep 2024 02:39:57 +0900 Subject: [PATCH 102/202] [Relax][PyTorch] Add support for `torch.nn.functional.conv*` (#17325) * add test for functional conv1d * add support for functional conv1d * cleanup conv1d * add test for functional conv_transpose1d * add support for functional conv_transpose1d * add test for functional conv_transpose2d * add support for functional conv_transpose2d * add test for functional conv3d * add support for functional conv3d --- .../tvm/relax/frontend/torch/fx_translator.py | 284 ++++++++++++++---- tests/python/relax/test_frontend_from_fx.py | 52 ++++ 2 files changed, 275 insertions(+), 61 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 676f63b5c359..245bb4cffb57 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -740,61 +740,140 @@ def _linear_functional(self, node: fx.node.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _conv1d(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - + def _conv1d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: conv1d = self.block_builder.emit( relax.op.nn.conv1d( x, weight, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, data_layout="NCW", kernel_layout="OIW", out_dtype="float32", ) ) - if module.bias is None: + if bias is None: return conv1d - - bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d, bias)) - def _conv3d(self, node: fx.node.Node) -> relax.Var: + def _conv1d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] - conv3d = self.block_builder.emit( - relax.op.nn.conv3d( + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv1d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( x, weight, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - data_layout="NCDHW", - kernel_layout="OIDHW", + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", out_dtype="float32", ) ) - if module.bias is None: - return conv3d + if bias is None: + return conv1d_transpose - bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) + def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] + + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) def _conv2d_impl( self, @@ -826,63 +905,142 @@ def _conv2d_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1)) return self.block_builder.emit(relax.op.add(conv2d, bias)) - def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: + def _conv2d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] - conv1d_transpose = self.block_builder.emit( - relax.op.nn.conv1d_transpose( + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv2d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( x, weight, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - data_layout="NCW", - kernel_layout="OIW", + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", out_dtype="float32", ) ) - if module.bias is None: - return conv1d_transpose + if bias is None: + return conv2d_transpose - bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] - conv2d_transpose = self.block_builder.emit( - relax.op.nn.conv2d_transpose( + return self._conv2d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv2d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv3d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( x, weight, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - data_layout="NCHW", - kernel_layout="OIHW", + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCDHW", + kernel_layout="OIDHW", out_dtype="float32", ) ) - if module.bias is None: - return conv2d_transpose - - bias = self.params[module.bias] + if bias is None: + return conv3d assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv3d, bias)) - def _conv2d(self, node: fx.node.Node) -> relax.Var: + def _conv3d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -890,7 +1048,7 @@ def _conv2d(self, node: fx.node.Node) -> relax.Var: if module.bias is not None: bias = self.params[module.bias] - return self._conv2d_impl( + return self._conv3d_impl( x, weight, bias=bias, @@ -900,7 +1058,7 @@ def _conv2d(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: + def _conv3d_functional(self, node: fx.node.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -909,7 +1067,7 @@ def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: padding = args[4] if len(args) > 4 else 0 dilation = args[5] if len(args) > 5 else 1 groups = args[6] if len(args) > 6 else 1 - return self._conv2d_impl( + return self._conv3d_impl( x, weight, bias=bias, @@ -1482,7 +1640,11 @@ def create_convert_map(self): "type": self._type, "astype": self._type, "matmul": self._matmul, + "conv1d": self._conv1d_functional, + "conv_transpose1d": self._conv1d_transpose_functional, "conv2d": self._conv2d_functional, + "conv_transpose2d": self._conv2d_transpose_functional, + "conv3d": self._conv3d_functional, "linear": self._linear_functional, "addmm": self._addmm, "baddbmm": self._baddbmm, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index c6c4f2597260..e191775a63b2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -48,6 +48,15 @@ def __init__(self): def forward(self, input): return self.conv(input) + class Conv1D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv1d(input, self.weight, self.bias) + @tvm.script.ir_module class expected1: @R.function @@ -113,6 +122,10 @@ def main( binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} verify_model(model, input_info, binding, expected1) + model = Conv1D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + model = Conv1D2() binding = {"w1": model.conv.weight.detach().numpy()} verify_model(model, input_info, binding, expected2) @@ -127,6 +140,15 @@ def __init__(self): def forward(self, input): return self.conv(input) + class ConvTranspose1d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 6, 3]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias) + @tvm.script.ir_module class expected1: @R.function @@ -192,6 +214,10 @@ def main( binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} verify_model(model, input_info, binding, expected1) + model = ConvTranspose1d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + model = ConvTranspose1d2() binding = {"w1": model.conv.weight.detach().numpy()} verify_model(model, input_info, binding, expected2) @@ -298,6 +324,15 @@ def __init__(self): def forward(self, input): return self.conv(input) + class ConvTranspose2d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[3, 3, 7, 7]) + self.bias = torch.randn(size=[3]) + + def forward(self, input): + return torch.nn.functional.conv_transpose2d(input, self.weight, self.bias) + @tvm.script.ir_module class expected1: @R.function @@ -363,6 +398,10 @@ def main( binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} verify_model(model, input_info, binding, expected1) + model = ConvTranspose2d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + model = ConvTranspose2d2() binding = {"w1": model.conv.weight.detach().numpy()} verify_model(model, input_info, binding, expected2) @@ -377,6 +416,15 @@ def __init__(self): def forward(self, input): return self.conv(input) + class Conv3D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7, 7, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv3d(input, self.weight, self.bias) + @tvm.script.ir_module class expected1: @R.function @@ -442,6 +490,10 @@ def main( binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} verify_model(model, input_info, binding, expected1) + model = Conv3D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + model = Conv3D2() binding = {"w1": model.conv.weight.detach().numpy()} verify_model(model, input_info, binding, expected2) From 8059c770dc563411717a44d9409888be3f85b7ee Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 3 Sep 2024 11:39:26 -0700 Subject: [PATCH 103/202] [KVCache] Add tree attention with paged cache support (#17326) --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 3 +- python/tvm/relax/frontend/nn/llm/tree_attn.py | 536 +++++++++++++++++- src/runtime/relax_vm/paged_kv_cache.cc | 384 ++++++++----- ...me_builtin_paged_attention_kv_cache_tir.py | 76 ++- 4 files changed, 828 insertions(+), 171 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 5ddce76eab40..7b14c67a2e57 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -30,7 +30,7 @@ from tvm.target import Target from .position_embedding import llama_rope_with_position_map, switch_rope_freq_func -from .tree_attn import tree_attn +from .tree_attn import tree_attn, tree_attn_with_paged_kv_cache def get_max_num_threads_per_block(target: Target) -> int: @@ -257,6 +257,7 @@ def __init__( # pylint: disable=too-many-locals bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), rope_ext_factors, # fmt: on # pylint: enable=line-too-long diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 069eb4892348..9e4a7ed97e71 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -62,11 +62,29 @@ def _rope( return expr -def _tree_mask(row, col, mask_ptr, offset, stride, kv_len): - return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1) +def _check_tree_order(tree_order_indptr, tree_order, batch, row, col, kv_len, qo_len): + tree_order_len = tree_order_indptr[batch + 1] - tree_order_indptr[batch] + + tree_start = kv_len - tree_order_len + child_idx_in_tree = row + tree_order_len - qo_len + parent_idx_in_tree = col - tree_start + return tir.all( + col < kv_len, + tir.any( + col < tree_start, + tir.all( + tree_order[tree_order_indptr[batch] + child_idx_in_tree, 0] + >= tree_order[tree_order_indptr[batch] + parent_idx_in_tree, 0], + tree_order[tree_order_indptr[batch] + child_idx_in_tree, 0] + < tree_order[tree_order_indptr[batch] + parent_idx_in_tree, 1], + ), + ), + ) -def tree_attn(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target): +def tree_attn( + h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target +): # pylint: disable=unused-argument """Generate tree attention kernel for batched tree attention. Parameters @@ -87,7 +105,7 @@ def tree_attn(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target) mod : tvm.IRModule The generated IR module. """ - # pylint: disable=line-too-long + # pylint: disable=invalid-name,line-too-long NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv @@ -140,7 +158,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset) - mask = T.match_buffer(var_mask, (tree_size,), "int32", elem_offset=mask_elem_offset) + mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset) output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable @@ -276,12 +294,13 @@ def batch_tree_attn( # pylint: disable=too-many-branches # mask out of kv_chunk_len S row_: T.int32 = (LH_start + row) // group_size for j in T.serial(tile_z): - if _tree_mask( + if _check_tree_order( row=row_, col=L_kv_start + j, - mask_ptr=mask, - offset=mn_indptr[b_idx], - stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + batch=b_idx, + tree_order=mask, + tree_order_indptr=mn_indptr, + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx], kv_len=kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) @@ -293,12 +312,13 @@ def batch_tree_attn( # pylint: disable=too-many-branches # this is to avoid sync inside condition branch if row < tile_x: row_: T.int32 = (LH_start + row) // group_size - if _tree_mask( + if _check_tree_order( row=row_, col=L_kv_start + j, - mask_ptr=mask, - offset=mn_indptr[b_idx], - stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + batch=b_idx, + tree_order=mask, + tree_order_indptr=mn_indptr, + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx], kv_len=kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: @@ -415,3 +435,493 @@ def apply_to_md(sch, block): apply_to_md(sch, sch.get_block("lse_store")) return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def tree_attn_with_paged_kv_cache( + h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target +): + """Generate tree attention kernel for batched tree attention with paged key-value cache. + + Parameters + ---------- + h_kv : int + Number of heads for key and value. + h_q : int + Number of heads for query. + d : int + Hidden dimension. + dtype : str + Data type. + target : Target + The target device. + + Returns + ------- + mod : tvm.IRModule + The generated IR module. + """ + # pylint: disable=import-outside-toplevel + from .kv_cache import ( + _declare_length_info, + _get_kv_chunk_len, + _get_seq_offset, + check_thread_limits, + ) + + # pylint: disable=invalid-name, line-too-long + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + bdx = 32 + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + check_thread_limits(target, bdx=bdx, bdy=num_warps, bdz=1, gdz=1) + + global_symbol = "tree_attn_paged_kv" + sliding_window = False # Sliding window is not supported in this kernel. + + # fmt: off + @T.prim_func + def tree_attn_paged_kv( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_len] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + tree_order_indptr_handle: T.handle, # [batch_size + 1] + tree_order_handle: T.handle, # [total_len, 2] + ): + # pylint: disable=unused-variable, too-many-branches + T.func_attr({"global_symbol": global_symbol}) + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + tree_order_elem_offset = T.int32(is_size_var=True) + tree_order_indptr_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (total_len, h_q, d), dtype) + q_indptr = T.match_buffer( + var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset + ) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + page_indptr = T.match_buffer( + var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset + ) + page_values = T.match_buffer( + var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset + ) + k_rope_pos_offset = T.match_buffer( + var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset + ) + q_rope_position = T.match_buffer( + var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset + ) + output = T.match_buffer(var_output, (total_len, h_q, d), dtype) + lse = T.match_buffer( + var_lse, (total_len, h_q), "float32" + ) # pylint: disable=unused-variable + tree_order_indptr = T.match_buffer( + tree_order_indptr_handle, + (batch_size + 1,), + "int32", + elem_offset=tree_order_indptr_elem_offset, + ) + total_tree_order_len = T.int32(is_size_var=True) + tree_order = T.match_buffer( + tree_order_handle, + (total_tree_order_len, 2), + "int32", + elem_offset=tree_order_elem_offset, + ) + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". + length_info = _declare_length_info( + var_length_info, batch_size, sliding_window, length_info_elem_offset + ) + + T.Assert( + rotary_mode == T.int32(0), "Inline rotary mode is not supported in tree attention." + ) + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x,), "float32", scope="shared") + + m_new = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + m_prev = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + d_new = T.alloc_buffer( + (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local" + ) + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = ( + q_indptr[b_idx + 1] - q_indptr[b_idx] + ) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * tile_x + q_indptr_val: T.int32 = q_indptr[b_idx] + + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len( + cur_page_indptr_end - cur_page_indptr_begin, + 16, + b_idx, + length_info, + sliding_window, + ), + 0, + ) + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope( + q, + q_rope_position[cur_L], + d, + rope_theta, + rope_scale, + (cur_L, cur_H_qo, j), + dtype, + rope_scaling, + ), + q[cur_L, cur_H_qo, j], + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + for lz, ly in T.grid(tile_z, tile_y): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + K_smem[i, j] = pages[ + page_no, 0, by, page_offset, j + ] + else: + K_smem[i, j] = 0.0 + + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, tile_y): + with T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + V_smem[i, j] = pages[ + page_no, 1, by, page_offset, j + ] + else: + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += ( + T.cast(Q_smem[i, k], "float32") + * T.cast(K_smem[j, k], "float32") + * attn_score_scaling_factor + * sm_scale + ) + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size + for j in T.serial(tile_z): + if _check_tree_order( + tree_order_indptr=tree_order_indptr, + tree_order=tree_order, + batch=b_idx, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] + - q_indptr[b_idx], + ): + m_new[i] = T.max( + m_new[i], S_smem[row, j] + ) + d_new[i] = d_smem[row] * T.exp2( + m_prev[i] - m_new[i] + ) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = ( + LH_start + row + ) // group_size + if _check_tree_order( + tree_order_indptr=tree_order_indptr, + tree_order=tree_order, + batch=b_idx, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] + - q_indptr[b_idx], + ): + S_smem[row, j] = T.exp2( + S_smem[row, j] - m_new[i] + ) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2( + m_prev_smem[i] - m_smem[i] + ) + O_local[i, j] += S_smem[i, k] * T.cast( + V_smem[k, j], "float32" + ) + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = ( + q_indptr[b_idx] + (LH_start + i) // group_size + ) + cur_H_qo: T.int32 = ( + by * group_size + (LH_start + i) % group_size + ) + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = ( + O_local[i, j] / d_smem[i] + ) + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = ( + q_indptr[b_idx] + (LH_start + i) // group_size + ) + cur_H_qo: T.int32 = ( + by * group_size + (LH_start + i) % group_size + ) + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + + # fmt: on + # pylint: enable=line-too-long,too-many-branches + sch = tir.Schedule(tree_attn_paged_kv) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 591187ab5fe7..8809a1b0729e 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -330,9 +330,9 @@ class PagedKVCacheAuxDataManager { */ virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0; /*! \brief Copy the tree attention mask. */ - virtual NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) = 0; + virtual NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the mn indptr of the tree attention mask. */ - virtual NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) = 0; + virtual NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Commit all the attention auxiliary data copy operations since the last commit. */ virtual void CommitAttnAuxDataCopy() = 0; @@ -379,14 +379,15 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { NDArray::Empty({3, reserved_num_seqs}, dtype_aux_, device)); k_rope_pos_offset_on_depths_device_.push_back( NDArray::Empty({reserved_num_seqs}, dtype_aux_, device)); + tree_attn_mask_device_.push_back(NDArray::Empty( + {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device)); + tree_attn_mn_indptr_device_.push_back( + NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); } cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - tree_attn_mask_device_ = NDArray::Empty( - {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device); - tree_attn_mn_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); commit_copy_src_dst_pos_in_page_table_device_ = @@ -450,15 +451,15 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) final { + NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { NDArray view = - tree_attn_mask_device_.CreateView({static_cast(data->size())}, dtype_aux_); + tree_attn_mask_device_[depth].CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) final { - NDArray view = - tree_attn_mn_indptr_device_.CreateView({static_cast(data->size())}, dtype_aux_); + NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = tree_attn_mn_indptr_device_[depth].CreateView( + {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } @@ -557,12 +558,12 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { std::vector page_indices_on_depths_device_; std::vector length_info_on_depths_device_; std::vector k_rope_pos_offset_on_depths_device_; + std::vector tree_attn_mask_device_; + std::vector tree_attn_mn_indptr_device_; NDArray cur_append_length_indptr_device_; NDArray k_ragged_rope_pos_offset_device_; NDArray q_rope_position_map_device_; NDArray append_position_map_device_; - NDArray tree_attn_mask_device_; - NDArray tree_attn_mn_indptr_device_; NDArray commit_copy_length_indptr_device_; NDArray commit_copy_src_dst_pos_in_page_table_device_; }; @@ -630,10 +631,11 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); + NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray mask_1d = CopyAttnAuxVecToCache(data); + return mask_1d.CreateView({static_cast(data->size() / 2), 2}, mask_1d->dtype); } - NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) final { + NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyAttnAuxVecToCache(data); } NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, @@ -894,7 +896,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief The append lengths of the sequences in the current round of forwarding. */ IntTuple cur_append_lengths_; /*! \brief Whether the current batch of sequences are token chains (not token trees). */ - bool is_chain_; + std::vector is_chain_on_depths_; /*! \brief Number of fork depth in the current round of forward. */ int num_depths_; /*! \brief Whether to compute attention after appending KV into cache or not. */ @@ -930,8 +932,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector q_rope_position_map_host_; HostMemoryVector append_position_map_host_; HostMemoryVector cur_append_lengths_indptr_host_; - HostMemoryVector tree_attn_mask_host_; - HostMemoryVector tree_attn_mn_indptr_host_; + std::vector tree_attn_mask_host_; + std::vector tree_attn_mn_indptr_host_; HostMemoryVector commit_copy_length_indptr_host_; HostMemoryVector commit_copy_src_pos_in_page_table_host_; HostMemoryVector commit_copy_dst_pos_in_page_table_host_; @@ -947,8 +949,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray k_ragged_rope_pos_offset_view_; NDArray q_rope_position_map_view_; NDArray append_position_map_view_; - NDArray tree_attn_mask_view_; - NDArray tree_attn_mn_indptr_view_; NDArray temp_attn_output_view_; NDArray temp_attn_scores_view_; NDArray merged_attn_scores_view_; @@ -957,6 +957,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector page_indices_on_depths_view_; std::vector length_info_on_depths_view_; std::vector k_rope_pos_offset_view_; + std::vector tree_attn_mask_view_; + std::vector tree_attn_mn_indptr_view_; PackedFunc f_transpose_append_; PackedFunc f_compact_copy_; @@ -966,6 +968,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { PackedFunc f_attention_decode_sliding_window_; PackedFunc f_attention_prefill_ragged_; PackedFunc f_attention_prefill_with_tree_mask_; + PackedFunc f_attention_prefill_with_tree_mask_paged_kv_; Optional f_attention_prefill_ragged_begin_forward_; Optional f_attention_prefill_ragged_end_forward_; Optional f_attention_prefill_begin_forward_; @@ -996,6 +999,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask, + PackedFunc f_attention_prefill_with_tree_mask_paged_kv, Optional f_attention_prefill_ragged_begin_forward, Optional f_attention_prefill_ragged_end_forward, Optional f_attention_prefill_begin_forward, @@ -1025,6 +1029,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)), f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)), f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)), + f_attention_prefill_with_tree_mask_paged_kv_( + std::move(f_attention_prefill_with_tree_mask_paged_kv)), f_attention_prefill_ragged_begin_forward_( std::move(f_attention_prefill_ragged_begin_forward)), f_attention_prefill_ragged_end_forward_(std::move(f_attention_prefill_ragged_end_forward)), @@ -1059,6 +1065,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); k_rope_pos_offset_on_depths_host_.push_back( HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + tree_attn_mask_host_.push_back(HostMemoryVector(kTreeAttnMaxTreeSize * 2 * reserved_num_seqs, + dtype_aux_, preferred_host_device)); + tree_attn_mn_indptr_host_.push_back( + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); } k_ragged_rope_pos_offset_host_ = HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device); @@ -1068,11 +1078,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device); cur_append_lengths_indptr_host_ = HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); - tree_attn_mask_host_ = - HostMemoryVector(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs, - dtype_aux_, preferred_host_device); - tree_attn_mn_indptr_host_ = - HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); commit_copy_length_indptr_host_ = HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); commit_copy_src_pos_in_page_table_host_ = @@ -1092,6 +1097,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indices_on_depths_view_.push_back(NDArray()); length_info_on_depths_view_.push_back(NDArray()); k_rope_pos_offset_view_.push_back(NDArray()); + tree_attn_mask_view_.push_back(NDArray()); + tree_attn_mn_indptr_view_.push_back(NDArray()); + is_chain_on_depths_.push_back(true); } // Additional workspace for the "prefill with ragged kv" kernel. if (NeedKernelBeginForward()) { @@ -1492,36 +1500,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sequences.push_back(&it->second); last_block_length_before_append.push_back( global_block_pool_[it->second.last_block_idx].seq_length); - k_ragged_rope_pos_offset_host_.push_back(it->second.seq_length); + int k_rope_offset = it->second.seq_length; + if (!it->second.accepted_indices_committed) { + int tree_size = static_cast(it->second.token_tree_parent_ptr.size()); + k_rope_offset -= tree_size; + } + k_ragged_rope_pos_offset_host_.push_back(k_rope_offset); it->second.seq_length += append_lengths[i]; if (append_lengths[i] != 1) { is_decode_request_ = false; } } - // - Check token tree validity and process the token tree. - is_chain_ = true; - tree_attn_mask_host_.clear(); - tree_attn_mn_indptr_host_.clear(); - if (opt_token_tree_parent_ptr.defined()) { - is_chain_ = ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value()); - } else { - // The input batch does not form trees. So each sequence in the batch - // is required to have all past accepted tokens committed. - for (int i = 0; i < cur_batch_size_; ++i) { - Sequence* sequence = sequences[i]; - CHECK(sequence->accepted_indices_committed) - << "The input batch does not form a tree, in which case the sequences in the input " - "batch are expected to have their accepted tokens token tree nodes committed. " - "Please invoke CommitAcceptedTokenTreeNodes for sequence " - << seq_ids[i]; - sequence->is_chain = true; - sequence->token_tree_parent_ptr.clear(); - sequence->token_tree_node_depths.clear(); - } - is_chain_ = true; - } - auto [block_ids_on_depths, trailing_blocks] = GetBlockIdsOnDepth(sequences); num_depths_ = std::min(static_cast(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth); @@ -1552,6 +1542,36 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false); } + bool has_previous_tree = + std::any_of(sequences.begin(), sequences.end(), + [](const Sequence* sequence) { return !sequence->accepted_indices_committed; }); + if (has_previous_tree) { + append_before_attn_ = true; + } + + // - Check token tree validity and process the token tree. + if (opt_token_tree_parent_ptr.defined()) { + CHECK(!support_sliding_window_) << "Tree attention does not support sliding window."; + CHECK(rope_mode_ != RoPEMode::kInline) << "Tree attention does not support inline RoPE mode."; + ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value(), block_ids_on_depths, + trailing_blocks); + } else { + // The input batch does not form trees. So each sequence in the batch + // is required to have all past accepted tokens committed. + for (int i = 0; i < cur_batch_size_; ++i) { + Sequence* sequence = sequences[i]; + CHECK(sequence->accepted_indices_committed) + << "The input batch does not form a tree, in which case the sequences in the input " + "batch are expected to have their accepted tokens token tree nodes committed. " + "Please invoke CommitAcceptedTokenTreeNodes for sequence " + << seq_ids[i]; + sequence->is_chain = true; + sequence->token_tree_parent_ptr.clear(); + sequence->token_tree_node_depths.clear(); + } + std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true); + } + if (append_before_attn_) { // Right now we use different kernels when depth is 1 or not 1. // For the case where maximum depth is 1, we create the auxiliary @@ -1656,9 +1676,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t append_length = append_lengths[i]; const Block& block = global_block_pool_[sequences[i]->last_block_idx]; for (int64_t pos = 0; pos < append_length; ++pos) { - q_rope_position_map_host_.push_back( - k_ragged_rope_pos_offset_host_[i] + - (is_chain_ ? pos : sequences[i]->token_tree_node_depths[pos])); + if (sequences[i]->token_tree_node_depths.empty()) { + q_rope_position_map_host_.push_back(k_ragged_rope_pos_offset_host_[i] + pos); + } else { + int64_t offset_in_tree = + static_cast(sequences[i]->token_tree_parent_ptr.size()) - append_length; + ICHECK_GE(offset_in_tree, 0); + q_rope_position_map_host_.push_back( + k_ragged_rope_pos_offset_host_[i] + + sequences[i]->token_tree_node_depths[offset_in_tree + pos]); + } int32_t pos_in_block = block.seq_length - append_length + pos; if (last_block_length_before_append[i] + pos < block.sink_length) { @@ -1763,12 +1790,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector sequences; sequences.reserve(num_seq_to_commit); + bool is_chain = true; for (int i = 0; i < num_seq_to_commit; ++i) { auto it = seq_map_.find(seq_ids[i]); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache."; sequences.push_back(&it->second); - CHECK(!it->second.accepted_indices_committed) + is_chain = it->second.is_chain; + CHECK(leaf_indices[i] == -1 || !it->second.accepted_indices_committed) << "The accepted nodes of sequence " << seq_ids[i] << " are already committed."; CHECK_GE(leaf_indices[i], -1) << "Invalid tree index " << leaf_indices[i] << " which is less than -1"; @@ -1778,7 +1807,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { << it->second.token_tree_parent_ptr.size() << " of the sequence"; } - if (!is_chain_) { + if (!is_chain) { commit_copy_length_indptr_host_.clear(); commit_copy_src_pos_in_page_table_host_.clear(); commit_copy_dst_pos_in_page_table_host_.clear(); @@ -1787,6 +1816,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int i = 0; i < num_seq_to_commit; ++i) { if (leaf_indices[i] == -1) { // No node is accepted. All nodes in the token tree need to be popped. + commit_copy_length_indptr_host_.push_back(commit_copy_length_indptr_host_.back()); continue; } @@ -1935,78 +1965,134 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { return block_idx; } - bool ConstructTokenTreeMask(const std::vector& sequences, - const IntTuple& token_tree_parent_ptr) { - // We check if the token tree deteriorates to a chain, - // because chain cases can have simplified attention work flow. - bool is_chain = true; - int64_t sum_new_append_length = 0; - // - Construct the mn indptr array, which is the indptr of the mask size of each sequence. - tree_attn_mn_indptr_host_.push_back(0); - ICHECK_EQ(sequences.size(), cur_batch_size_); - ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_); - for (int i = 0; i < cur_batch_size_; ++i) { - int64_t append_length = cur_append_lengths_[i]; - // Update the token tree parent pointers. - sequences[i]->token_tree_parent_ptr = { - token_tree_parent_ptr->data + sum_new_append_length, - token_tree_parent_ptr->data + sum_new_append_length + cur_append_lengths_[i]}; - sum_new_append_length += cur_append_lengths_[i]; - - CHECK_LE(append_length, kTreeAttnMaxTreeSize) - << "The tree size is " << append_length << " which exceeds the maximum tree size limit " - << kTreeAttnMaxTreeSize; - tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() + - append_length * append_length); - } - CHECK_EQ(token_tree_parent_ptr.size(), sum_new_append_length) - << "Invalid token tree size. The sum of \"append_lengths\" is " << sum_new_append_length - << " while there are " << token_tree_parent_ptr.size() - << " elements in \"token_tree_parent_ptr\"."; - - // - Construct the mask of each sequence. - for (int i = 0; i < cur_batch_size_; ++i) { - int64_t tree_size = sequences[i]->token_tree_parent_ptr.size(); - std::vector> mask; - std::vector depth; - mask.reserve(tree_size); - depth.reserve(tree_size); - sequences[i]->is_chain = true; - sequences[i]->accepted_indices_committed = false; - for (int64_t n = 0; n < tree_size; ++n) { - CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n) - << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " - << sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n; - CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1) - << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " - << sequences[i]->token_tree_parent_ptr[n]; - if (sequences[i]->token_tree_parent_ptr[n] != n - 1) { - // The parent of the current node is not the last node. - // Therefore the tree is not a chain. - sequences[i]->is_chain = false; - is_chain = false; + void ConstructTokenTreeMask(const std::vector& sequences, + const IntTuple& token_tree_parent_ptr, + const std::vector>& block_ids_on_depths, + const std::vector>& trailing_blocks) { + // Check whether the token tree of a sequence should be handled at the current depth. + auto check_for_sequence = [&](int seq_i, int depth) -> bool { + if (!append_before_attn_) { + return true; + } + // Check if the last block of the sequence is on the current depth. + if (block_ids_on_depths[depth][seq_i] == sequences[seq_i]->last_block_idx || + (depth + 1 == kPagedKVCacheMaxBlockDepth && !trailing_blocks[seq_i].empty())) { + return true; + } + return false; + }; + for (int d = 0; d < num_depths_; ++d) { + // We check if the token tree deteriorates to a chain, + // because chain cases can have simplified attention work flow. + ICHECK_LT(d, tree_attn_mask_host_.size()); + ICHECK_LT(d, tree_attn_mn_indptr_host_.size()); + HostMemoryVector& tree_attn_mn_indptr = tree_attn_mn_indptr_host_[d]; + HostMemoryVector& tree_attn_mask = tree_attn_mask_host_[d]; + + std::vector seq_in_current_depth(cur_batch_size_, false); + + tree_attn_mn_indptr.clear(); + tree_attn_mask.clear(); + std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true); + + bool is_chain = true; + // - Construct the mn indptr array, which is the indptr of the mask size of each sequence. + tree_attn_mn_indptr.push_back(0); + ICHECK_EQ(sequences.size(), cur_batch_size_); + ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_); + int64_t token_tree_parent_ptr_offset = 0; + for (int i = 0; i < cur_batch_size_; ++i) { + int64_t append_length = cur_append_lengths_[i]; + seq_in_current_depth[i] = check_for_sequence(i, d); + if (!seq_in_current_depth[i]) { + tree_attn_mn_indptr.push_back(tree_attn_mn_indptr.back()); + token_tree_parent_ptr_offset += append_length; // Skip the token tree of this sequence. + continue; + } + // Update the token tree parent pointers. + CHECK_LE(sequences[i]->token_tree_parent_ptr.size(), + global_block_pool_[sequences[i]->last_block_idx].seq_length) + << "The token tree size is larger than the sequence length of the last block."; + std::copy(token_tree_parent_ptr.begin() + token_tree_parent_ptr_offset, + token_tree_parent_ptr.begin() + token_tree_parent_ptr_offset + append_length, + std::back_inserter(sequences[i]->token_tree_parent_ptr)); + token_tree_parent_ptr_offset += append_length; + + CHECK_LE(sequences[i]->token_tree_parent_ptr.size(), kTreeAttnMaxTreeSize) + << "The tree size is " << append_length << " which exceeds the maximum tree size limit " + << kTreeAttnMaxTreeSize; + tree_attn_mn_indptr.push_back(tree_attn_mn_indptr.back() + + sequences[i]->token_tree_parent_ptr.size()); + } + CHECK_EQ(token_tree_parent_ptr.size(), token_tree_parent_ptr_offset) + << "Invalid token tree size. The sum of \"append_lengths\" is " + << token_tree_parent_ptr_offset << " while there are " << token_tree_parent_ptr.size() + << " elements in \"token_tree_parent_ptr\"."; + + // - Construct the mask of each sequence. + for (int i = 0; i < cur_batch_size_; ++i) { + if (!seq_in_current_depth[i]) { + continue; } + int64_t tree_size = sequences[i]->token_tree_parent_ptr.size(); + std::vector> mask; + std::vector depth; + mask.reserve(tree_size); + depth.reserve(tree_size); + sequences[i]->is_chain = true; + sequences[i]->accepted_indices_committed = false; + std::unordered_map> tree_parent_to_children; + std::vector tree_roots; + for (int n = 0; n < tree_size; ++n) { + CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n) + << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " + << sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n; + CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1) + << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " + << sequences[i]->token_tree_parent_ptr[n]; + if (sequences[i]->token_tree_parent_ptr[n] != n - 1) { + // The parent of the current node is not the last node. + // Therefore the tree is not a chain. + sequences[i]->is_chain = false; + is_chain = false; + } + tree_parent_to_children[sequences[i]->token_tree_parent_ptr[n]].push_back(n); - std::vector single_pos_mask; - if (sequences[i]->token_tree_parent_ptr[n] != -1) { - // The current node has a parent in the token tree. - single_pos_mask = {mask[sequences[i]->token_tree_parent_ptr[n]].begin(), - mask[sequences[i]->token_tree_parent_ptr[n]].end()}; - depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1); - } else { - // The current node is root in the token tree. - single_pos_mask.resize(tree_size, /*value=*/0); - depth.push_back(0); + if (sequences[i]->token_tree_parent_ptr[n] != -1) { + depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1); + } else { + depth.push_back(0); + tree_roots.push_back(n); + } + } + std::vector> tree_order(tree_size); + int order = 0; + std::function tree_dfs = [&order, &tree_order, &tree_parent_to_children, + &tree_dfs](int node) -> int { + tree_order[node].first = order++; + int upper_bound = tree_order[node].first + 1; + for (int child : tree_parent_to_children[node]) { + upper_bound = std::max(upper_bound, tree_dfs(child)); + } + tree_order[node].second = upper_bound; + return upper_bound; + }; + for (auto root : tree_roots) { + tree_dfs(root); } - single_pos_mask[n] = 1; - mask.push_back(single_pos_mask); - for (int32_t mask_val : single_pos_mask) { - tree_attn_mask_host_.push_back(mask_val); + for (int n = 0; n < tree_size; ++n) { + tree_attn_mask.push_back(tree_order[n].first); + tree_attn_mask.push_back(tree_order[n].second); } + sequences[i]->token_tree_node_depths = std::move(depth); + } + + is_chain_on_depths_[d] = is_chain; + + if (!append_before_attn_) { + break; } - sequences[i]->token_tree_node_depths = std::move(depth); } - return is_chain; } /*! @@ -2236,13 +2322,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } if (!append_before_attn_) { - if (is_chain_) { + if (is_chain_on_depths_[0]) { f_attention_prefill_ragged_begin_forward_.value()( temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_); - } else { - LOG(FATAL) << "Kernel BeginForward doesn't support tree attn."; } } for (int d = 0; d < num_depths_; ++d) { @@ -2285,7 +2369,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (!append_before_attn_) { // The first part of attention, which only involves the q and the newly appended k/v. is_first_kernel = false; - if (is_chain_) { + if (is_chain_on_depths_[0]) { // If the batch does not form a tree, use raggedness prefill kernel. f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data, cur_append_length_indptr_view_, q_rope_position_map_view_, @@ -2296,14 +2380,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { rotary_theta_, attn_score_scaling_factor); } else { // The batch requires tree attention. - ICHECK(tree_attn_mask_view_.defined()); - ICHECK(tree_attn_mn_indptr_view_.defined()); ICHECK(f_attention_prefill_with_tree_mask_.defined()) << "Function \"f_attention_prefill_with_tree_mask_\" is not defined."; + ICHECK(tree_attn_mask_view_[0].defined()); + ICHECK(tree_attn_mn_indptr_view_[0].defined()); f_attention_prefill_with_tree_mask_( q_data, cur_append_length_indptr_view_, k_data, v_data, cur_append_length_indptr_view_, - q_rope_position_map_view_, tree_attn_mn_indptr_view_, tree_attn_mask_view_, output, - merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, + q_rope_position_map_view_, tree_attn_mn_indptr_view_[0], tree_attn_mask_view_[0], + output, merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, attn_score_scaling_factor, cur_batch_size_); } } @@ -2321,7 +2405,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_output = temp_attn_output_view_; attn_scores = temp_attn_scores_view_; } - if (use_decode_kernel_[d]) { + if (append_before_attn_ && !is_chain_on_depths_[d]) { + f_attention_prefill_with_tree_mask_paged_kv_( + /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], + page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], + length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, + attn_output, attn_scores, + /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, + attn_score_scaling_factor, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d]); + } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], length_info_on_depths_view_[d], @@ -2446,13 +2538,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { append_position_map_view_ = aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_); // 10. tree_attn_mask and tree_attn_mn_indptr - if (!is_chain_) { - tree_attn_mask_view_ = aux_data_manager_->CopyTreeAttnMaskAsync(&tree_attn_mask_host_); - tree_attn_mn_indptr_view_ = - aux_data_manager_->CopyTreeAttnMNIndptrAsync(&tree_attn_mn_indptr_host_); - } else { - tree_attn_mask_view_ = NDArray{nullptr}; - tree_attn_mn_indptr_view_ = NDArray{nullptr}; + for (int d = 0; d < num_depths_; ++d) { + if (!is_chain_on_depths_[d]) { + tree_attn_mask_view_[d] = + aux_data_manager_->CopyTreeAttnMaskOnDepthAsync(&tree_attn_mask_host_[d], d); + tree_attn_mn_indptr_view_[d] = + aux_data_manager_->CopyTreeAttnMNIndptrOnDepthAsync(&tree_attn_mn_indptr_host_[d], d); + } } // 11. Create view for temporary arrays for attention computation. temp_attn_output_view_ = temp_attn_output_device_.CreateView( @@ -2477,7 +2569,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 27 || args.size() == 28) + CHECK(args.size() == 28 || args.size() == 29) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; @@ -2516,10 +2608,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") Optional f_debug_get_kv = args[24]; PackedFunc f_compact_copy = args[25]; PackedFunc f_attention_prefill_with_tree_mask = args[26]; + PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[27]; Optional rope_ext_factors = NullOpt; - if (args.size() >= 28 && args[27].IsObjectRef()) { - rope_ext_factors = args[27].AsObjectRef(); + if (args.size() >= 29 && args[28].IsObjectRef()) { + rope_ext_factors = args[28].AsObjectRef(); } CHECK_EQ(cache_config.size(), 5); @@ -2542,6 +2635,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), + std::move(f_attention_prefill_with_tree_mask_paged_kv), std::move(f_attention_prefill_ragged_begin_forward), std::move(f_attention_prefill_ragged_end_forward), std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), @@ -2553,7 +2647,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 21 || args.size() == 22) + CHECK(args.size() == 22 || args.size() == 23) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; @@ -2586,10 +2680,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") Optional f_debug_get_kv = args[18]; PackedFunc f_compact_copy = args[19]; PackedFunc f_attention_prefill_with_tree_mask = args[20]; + PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[21]; Optional rope_ext_factors = NullOpt; - if (args.size() >= 22 && args[21].IsObjectRef()) { - rope_ext_factors = args[21].AsObjectRef(); + if (args.size() >= 23 && args[22].IsObjectRef()) { + rope_ext_factors = args[22].AsObjectRef(); } CHECK_EQ(cache_config.size(), 5); @@ -2611,8 +2706,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") std::move(f_attention_prefill), std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), - std::move(f_attention_prefill_with_tree_mask), // - NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // + std::move(f_attention_prefill_with_tree_mask), // + std::move(f_attention_prefill_with_tree_mask_paged_kv), // + NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index c35b7062cdc2..5ab96caa9bc0 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -36,6 +36,7 @@ _merge_state_inplace, llama_rope_with_position_map, tree_attn, + tree_attn_with_paged_kv_cache, ) from tvm.runtime import ShapeTuple @@ -74,6 +75,7 @@ fattn_decode_sliding_window = None fattn_prefill_ragged = None fattn_prefill_with_tree_mask = None +fattn_prefill_with_tree_mask_paged_kv_cache = None fmerge_state = None fsplit_rotary = None fattention_rotary = None @@ -86,7 +88,7 @@ def set_global_func(head_dim, dtype): global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode - global fattn_prefill_ragged, fattn_prefill_with_tree_mask + global fattn_prefill_ragged, fattn_prefill_with_tree_mask, fattn_prefill_with_tree_mask_paged_kv_cache global fattn_prefill_sliding_window, fattn_decode_sliding_window global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, fcompact_copy @@ -124,6 +126,9 @@ def set_global_func(head_dim, dtype): num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target ), tree_attn(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target), + tree_attn_with_paged_kv_cache( + num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target + ), _merge_state_inplace(num_qo_heads, head_dim, dtype, target), llama_rope_with_position_map( rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype, rope_scaling @@ -146,6 +151,7 @@ def set_global_func(head_dim, dtype): fattn_decode_sliding_window, fattn_prefill_ragged, fattn_prefill_with_tree_mask, + fattn_prefill_with_tree_mask_paged_kv_cache, fmerge_state, fsplit_rotary, fcopy_single_page, @@ -185,6 +191,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): fcopy_cache, fcompact_copy, fattn_prefill_with_tree_mask, + fattn_prefill_with_tree_mask_paged_kv_cache, None, ) return cache @@ -206,7 +213,7 @@ class RopeMode(enum.IntEnum): params=itertools.chain( itertools.product( [64, 128], - ["float16", "float32"], + ["float32", "float16"], [RopeMode.NORMAL], [False], ), @@ -296,23 +303,26 @@ def apply_attention( cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - assert (token_tree_parent_ptr_list is None) == (accepted_leaf_indices is None) flattened_token_tree_parent_ptr = None token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch] if token_tree_parent_ptr_list: assert len(token_tree_node_depths_list) == len(seq_ids) - assert len(accepted_leaf_indices) == len(seq_ids) + if accepted_leaf_indices is not None: + assert len(accepted_leaf_indices) == len(seq_ids) flattened_token_tree_parent_ptr = [] for i, (token_tree_parent_ptr, append_length) in enumerate( zip(token_tree_parent_ptr_list, append_lengths) ): - assert len(token_tree_parent_ptr) == append_length - flattened_token_tree_parent_ptr += token_tree_parent_ptr + assert len(token_tree_parent_ptr) >= append_length + # parent pointer for the last `append_length` nodes (the new tokens) + append_token_tree_parent_ptr = token_tree_parent_ptr[-append_length:] + flattened_token_tree_parent_ptr += append_token_tree_parent_ptr token_tree_node_depths = [] for parent in token_tree_parent_ptr: token_tree_node_depths.append( 0 if parent == -1 else token_tree_node_depths[parent] + 1 ) + # depth of each node in the tree (this contains more than the last `append_length` nodes) token_tree_node_depths_list[i] = token_tree_node_depths fbegin_forward( @@ -337,6 +347,11 @@ def apply_attention( new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) q_array.append(new_q) + rope_offset = cached_k[seq_id].shape[1] + if token_tree_parent_ptr_list is not None: + prev_tree_size = len(token_tree_parent_ptr_list[i]) - append_length + assert prev_tree_size >= 0 + rope_offset -= prev_tree_size cached_k[seq_id] = np.concatenate( [ cached_k[seq_id], @@ -347,10 +362,12 @@ def apply_attention( if rope_mode != RopeMode.NORMAL else f_apply_rotary( new_k[l], - cached_k[seq_id].shape[1], + rope_offset, rope_scale, rope_theta, - token_tree_node_depths_list[i], + token_tree_node_depths_list[i][-append_length:] + if token_tree_node_depths_list[i] is not None + else None, ) ) for l in range(num_layers) @@ -379,7 +396,11 @@ def apply_attention( for i, (seq_id, append_length) in enumerate(batch): assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length - rope_offset = cached_k[seq_id].shape[1] - append_length + rope_offset = cached_k[seq_id].shape[1] + if token_tree_parent_ptr_list is not None: + rope_offset -= len(token_tree_parent_ptr_list[i]) + else: + rope_offset -= append_length q_seq = ( q_array[i][layer_id] if rope_mode == RopeMode.NONE @@ -388,7 +409,9 @@ def apply_attention( rope_offset, rope_scale, rope_theta, - token_tree_node_depths_list[i], + token_tree_node_depths_list[i][-append_length:] + if token_tree_node_depths_list[i] is not None + else None, ) ).transpose(1, 0, 2) k_seq = ( @@ -422,15 +445,16 @@ def apply_attention( np.full_like(softmax_input, np.finfo("float32").max), k=length_diff ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) if token_tree_parent_ptr_list is not None: + tree_size = len(token_tree_parent_ptr_list[i]) tree_mask = np.full( - (append_length, append_length), np.finfo("float32").min, dtype="float32" + (tree_size, tree_size), np.finfo("float32").min, dtype="float32" ) for i, parent in enumerate(token_tree_parent_ptr_list[i]): if parent != -1: tree_mask[i] = tree_mask[parent] tree_mask[i, i] = np.finfo("float32").max tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, *tree_mask.shape)) - mask[:, :, length_diff:] = tree_mask + mask[:, :, -tree_size:] = tree_mask[:, -append_length:, :] softmax_input = np.minimum(softmax_input, mask) @@ -846,9 +870,12 @@ def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config): @tvm.testing.requires_cuda def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config - if support_sliding_window and rope_mode == RopeMode.NORMAL: + if support_sliding_window: # Normal RoPE mode under sliding window settings is not supported. return + if rope_mode == RopeMode.INLINE: + # Inline RoPE mode is not supported for tree attention. + return fclear(kv_cache) cached_k = {} @@ -899,6 +926,29 @@ def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): for _ in range(5): apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v) + # Test the cases of tree attn with cached kv. + fclear(kv_cache) + cached_k = {} + cached_v = {} + # Prefill 4 sequences + apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v) + # Do 5 rounds of tree decode. + num_seq = 4 + for i in range(5): + num_leaf_nodes = 2**i + parent_ptr = [(k - 1) // 2 for k in range(0, 2 * num_leaf_nodes - 1)] + apply_attention( + kv_cache, + rope_mode, + [(seq_id, num_leaf_nodes) for seq_id in range(num_seq)], + cached_k, + cached_v, + token_tree_parent_ptr_list=[parent_ptr for _ in range(num_seq)], + accepted_leaf_indices=( + None if i != 4 else [2, 6, -1, 4] + ), # Leaf nodes are committed all at once at the end. + ) + if __name__ == "__main__": HEAD_DIMS = [64, 128] From fd139c3dd7639843ac06e5664206a06458b8586f Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 4 Sep 2024 21:00:21 +0800 Subject: [PATCH 104/202] [Doc] How to Optimize a Language Model (#17327) This tutorial demonstrates how to optimize a language model using TVM. --- docs/conf.py | 1 - docs/how_to/index.rst | 24 - docs/how_to/tutorials/optimize_llm.py | 614 ++++++++++++++++++ docs/index.rst | 6 +- docs/legacy_redirect.py | 1 - .../how_to/work_with_schedules/intrin_math.py | 173 ----- 6 files changed, 619 insertions(+), 200 deletions(-) delete mode 100644 docs/how_to/index.rst create mode 100644 docs/how_to/tutorials/optimize_llm.py delete mode 100644 gallery/how_to/work_with_schedules/intrin_math.py diff --git a/docs/conf.py b/docs/conf.py index c933653233b1..1ffc4dcafdb2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -488,7 +488,6 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): "work_with_schedules": [ "schedule_primitives.py", "reduction.py", - "intrin_math.py", "scan.py", "extern_op.py", "tensorize.py", diff --git a/docs/how_to/index.rst b/docs/how_to/index.rst deleted file mode 100644 index c5b9d703f032..000000000000 --- a/docs/how_to/index.rst +++ /dev/null @@ -1,24 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. toctree:: - :maxdepth: 1 - - tutorials/e2e_opt_model - tutorials/customize_opt - tutorials/cross_compilation_and_rpc - dev/index diff --git a/docs/how_to/tutorials/optimize_llm.py b/docs/how_to/tutorials/optimize_llm.py new file mode 100644 index 000000000000..9311c0557fe7 --- /dev/null +++ b/docs/how_to/tutorials/optimize_llm.py @@ -0,0 +1,614 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _opt_llm: + +Optimize Large Language Model +============================= +As large language models (LLMs) have become a popular research topic in many different fields, +deploying them on cloud and edge devices has become a challenging task. In this tutorial, we will +demonstrate how to optimize a large language model using Apache TVM. We will use a pre-trained +TinyLlama model from Hugging Face and deploy it on various devices. +""" + +###################################################################### +# Review Overall Flow +# ------------------- +# .. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg +# :align: center +# :width: 80% +# +# The overall flow consists of the following steps: +# +# - **Construct or Import a Model**: Construct a neural network model or import a pre-trained +# model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains +# all the information needed for compilation, including high-level Relax functions for +# computational graph, and low-level TensorIR functions for tensor program. +# - **Perform Composable Optimizations**: Perform a series of optimization transformations, +# such as graph optimizations, tensor program optimizations, and library dispatching. +# - **Build and Universal Deployment**: Build the optimized model to a deployable module to the +# universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators. +# + + +###################################################################### +# Construct the model architecture +# -------------------------------- +# We will use a pre-trained TinyLlama model from Hugging Face. However, usually we only load the +# pre-trained weight from Hugging Face but not the model architecture. We need to construct the +# model architecture by ourselves. Apache TVM prepares a PyTorch-liked API to construct the model +# architecture. We can use the API to construct the model architecture. + + +import dataclasses +import enum +import os +from pathlib import Path +from pprint import pprint +from typing import List, Optional + +import tvm +from tvm import dlight, relax, te, tir +from tvm.relax import register_pipeline +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op +from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache, TIRPagedKVCache +from tvm.runtime import ShapeTuple + +###################################################################### +# First, we need to define the model configuration. The configuration includes the key parameters +# of the model, such as hidden size, intermediate size, etc. Here for convenience, we define a +# constant config specially for the TinyLlama model. + + +@dataclasses.dataclass +class LlamaConfig: + hidden_size: int = 2048 + intermediate_size: int = 5632 + num_attention_heads: int = 32 + num_hidden_layers: int = 22 + rms_norm_eps: float = 1e-05 + vocab_size: int = 32000 + rope_theta: int = 10000 + context_window_size: int = 2048 + prefill_chunk_size: int = 2048 + num_key_value_heads: int = 4 + head_dim: int = 64 # hidden_size // num_attention_heads + + +dev = tvm.device("cuda", 0) +target = tvm.target.Target.from_device(dev) + + +###################################################################### +# Next, we define the RoPE mode of the Paged KV cache. The RoPE mode is used to apply the +# Relative Positional Encoding (RoPE) to the query and key tensors. The RoPE mode can be set to +# `NONE`, `NORMAL`, or `INLINE`. If the RoPE mode is `NONE`, the KV cache will not apply RoPE to +# the query and key tensors. If the RoPE mode is `NORMAL`, RoPE will be applied to the key tensor +# before adding the key tensor to the cache. If the RoPE mode is `INLINE`, RoPE will be applied to +# the query and key tensors in the attention kernel on-the-fly. + + +class RopeMode(enum.IntEnum): + """The RoPE mode of the Paged KV cache. + If it is none, the KV cache will not apply RoPE to q and k. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. + """ + + NONE = 0 + NORMAL = 1 + INLINE = 2 + + +###################################################################### +# Secondly, we define the model architecture. The model architecture consists of three parts: +# +# - Embedding layer: The embedding layer converts the input token IDs to the hidden states. +# - Decoder layers: The decoder layers are the core of the model. Each decoder layer consists of +# a self-attention layer and a feed-forward network (FFN) layer. +# - Output layer: The output layer converts the hidden states to the logits. +# +# First we define the FFN layer. Note that the following FFN layer is optimized implementation +# where we fuse the gate and up projection into one kernel. +# The naive implementation of FFN layer is: ``FFN(x) = down_proj(silu(gate(x)) * up(x))`` +# We could combine the ``gate`` and ``up`` projection into one kernel for better performance. +# The optimized implementation is: +# +# .. code-block:: python +# +# concat_x = gate_up(x) +# gate_x, up_x = split(concat_x, 2, axis=-1) +# FFN(x) = down_proj(silu(gate_x) * up_x) +# + + +class LlamaFFN(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * config.intermediate_size, + bias=False, + ) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(op.silu(x1) * x2) + + +###################################################################### +# Then we define the self-attention layer. The self-attention layer consists of three parts: +# +# - QKV projection: The QKV projection converts the input hidden states to the query, key, and +# value tensors. +# - Attention: The attention layer computes the attention scores and applies the softmax +# operation. +# - Output projection: The output projection converts the attention output to the hidden states. +# +# We perform optimizations on the different parts of the self-attention layer: +# +# - QKV projection: We leverage the horizontal fusion on QKV projection and fuse them into one +# kernel. +# - Attention: We leverage the horizontal fusion on attention and fuse the QKV projection and + + +class LlamaAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: LlamaConfig): + self.head_dim = config.head_dim + self.num_q_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + # horizontal fusion on QKV projection + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + # Output Projection + return self.o_proj(output) + + +###################################################################### +# Finally, we define the model architecture with FFN and self-attention layers. + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + rms_norm_eps = config.rms_norm_eps + self.self_attn = LlamaAttention(config) + self.mlp = LlamaFFN(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + hidden_states += self.self_attn( + self.input_layernorm(hidden_states), paged_kv_cache, layer_id + ) + hidden_states += self.mlp(self.post_attention_layernorm(hidden_states)) + return hidden_states + + +class LlamaModel(nn.Module): + def __init__(self, config: LlamaConfig): + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = input_embed + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class LlamaForCasualLM(nn.Module): + def __init__(self, config: LlamaConfig): + self.model = LlamaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.rope_theta + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def embed(self, input_ids: Tensor): + return self.model.embed_tokens(input_ids) + + def get_logits(self, hidden_states: Tensor): + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.get_logits(hidden_states) + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.get_logits(hidden_states) + return logits, paged_kv_cache + + def create_tir_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return TIRPagedKVCache( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=0, + layer_partition=relax.ShapeExpr([0, self.num_hidden_layers]), + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + rope_scaling={}, + rope_ext_factors=relax.PrimValue(0), + rotary_dim=self.head_dim, + dtype=self.dtype, + target=target, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_tir_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) + + +###################################################################### +# Export the model to Relax IRModule +# ---------------------------------- +# After defining the model architecture, we can export the model to the Relax IRModule. +# For demonstration, we only show the part of the model architecture. and parameters. + +model_config = LlamaConfig() +model = LlamaForCasualLM(model_config) +model.to("float16") +mod, named_params = model.export_tvm(spec=model.get_default_spec()) +prefill_str = mod["prefill"].script() +print(*prefill_str.split("\n")[3:20], sep="\n") # Only show the first 10 lines for demonstration +print(" ...") + +print("\nParameters:") +pprint(named_params[:5]) # Only show the first 5 parameters for demonstration + +###################################################################### +# Define Optimization Pipeline +# ---------------------------- +# We define a series of optimization passes to optimize the model. The optimization pipeline +# is designed specifically for the LLMs. + + +@register_pipeline("opt_llm") +def _pipeline( # pylint: disable=too-many-arguments + ext_mods: List[nn.ExternModule] = None, +): + ext_mods = ext_mods or [] + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + seq = tvm.transform.Sequential( + [ + # Phase 1. Passes on high-level operator graph + # We can enable cublas for further optimization + relax.transform.FuseTransposeMatmul(), + # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline + relax.transform.LegalizeOps(), + relax.transform.AnnotateTIROpPattern(), + relax.transform.FoldConstant(), + relax.transform.FuseOps(), + relax.transform.FuseTIR(), + # Phase 3. Passes on TIR + relax.transform.DeadCodeElimination(), + # Phase 4. Low-level Optimizations + dlight.ApplyDefaultSchedule( + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + ), + # Phase 5. Lowering to VM bytecode + relax.transform.RewriteDataflowReshape(), + relax.transform.ToNonDataflow(), + relax.transform.RemovePurityChecking(), + relax.transform.CallTIRRewrite(), + relax.transform.StaticPlanBlockMemory(), + relax.transform.RewriteCUDAGraph(), + relax.transform.LowerAllocTensor(), + relax.transform.KillAfterLastUse(), + relax.transform.LowerRuntimeBuiltin(), + relax.transform.VMShapeLower(), + relax.transform.AttachGlobalSymbol(), + relax.transform.AttachExternModules(ext_mods), + ] + ) + mod = seq(mod) + return mod + + return _pipeline + + +with target: + ex = relax.build(mod, target, pipeline=relax.get_pipeline("opt_llm")) + vm = relax.VirtualMachine(ex, dev) + + +###################################################################### +# Prepare the model weights +# ------------------------- +# We load the pre-trained weights from Hugging Face and prepare the model weights. +# The pre-trained weights are stored in the Hugging Face format. We need to load the weights +# and prepare the model parameters. +# +# .. note:: +# +# Note that we won't execute the following code in this tutorial because the pre-trained weights +# are not available in the CI environment. +# + + +IS_IN_CI = os.getenv("CI", "") == "true" + +HF_WEIGHT_PATH = None +# HF_WEIGHT_PATH = Path("/path/to/TinyLlama-1.1B-Chat-v1.0/") + +if not IS_IN_CI: + import numpy as np + import safetensors.torch + import torch + + if HF_WEIGHT_PATH is None or not HF_WEIGHT_PATH.exists(): + raise ValueError("Please set the HF_WEIGHT_PATH to the path of the pre-trained weights.") + + # Torch format weights + param_dict = safetensors.torch.load_file(HF_WEIGHT_PATH / "model.safetensors", device="cpu") + # Numpy format weights + param_dict = { + k: v.half().numpy() if v.dtype == torch.bfloat16 else v.numpy() + for k, v in param_dict.items() + } + + named_params = dict(named_params) + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + param_dict[f"{attn}.qkv_proj.weight"] = np.concatenate( + [ + param_dict.pop(f"{attn}.q_proj.weight"), # Pop the old parameters to save memory + param_dict.pop(f"{attn}.k_proj.weight"), + param_dict.pop(f"{attn}.v_proj.weight"), + ], + axis=0, + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + param_dict[f"{mlp}.gate_up_proj.weight"] = np.concatenate( + [ + param_dict.pop(f"{mlp}.gate_proj.weight"), + param_dict.pop(f"{mlp}.up_proj.weight"), + ], + axis=0, + ) + + # Convert params into ndarray + params = [ + tvm.nd.array(param_dict[k].astype("float16"), device=dev) for k in named_params.keys() + ] + + +###################################################################### +# Deploy the compiled model +# ------------------------- +# After the model and weights are ready, we can deploy the compiled model on the target device. +# The language models inference includes two steps: prefill and decode. The prefill step is +# used to process the input tokens and store the KVCache. The decode step is used to generate +# the token until the end token is generated. + + +###################################################################### +# Tokenization +# ~~~~~~~~~~~~ +# The first step is to tokenize the input prompt and embed the tokens into the hidden states. +# The tokenization and embedding are the same as the original model. We use the HF tokenizer +# to tokenize the input prompt and embed the tokens into the hidden states. +# Note that different models require different tokenization and prompt format, please refer to +# the model documentation for the correct tokenization and prompt format. + + +if not IS_IN_CI: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(HF_WEIGHT_PATH) + messages = [ + {"role": "user", "content": "What's your name?"}, + ] + prompt = tokenizer.apply_chat_template(messages) + input_len = len(prompt) + + # Load prompt tokens into TVM ndarray on the target device + tokens = tvm.nd.array(np.array(prompt).astype("int32"), device=dev) + +###################################################################### +# Create the KVCache +# ~~~~~~~~~~~~~~~~~~ +# Before starting the inference, we need to create the KVCache. The KVCache is used to store the +# key and value tensors for the attention layer. Apache TVM provides a PagedKVCache to store the +# key and value tensors. We create the PagedKVCache with the specified parameters. + +if not IS_IN_CI: + kv_cache = vm["create_tir_paged_kv_cache"]( + ShapeTuple([1]), # max_batch_size=1 + ShapeTuple([2048]), # max_total_seq_len=2048 + ShapeTuple([2048]), # prefill_chunk_size=2048 + ShapeTuple([16]), # page_size=16 + ) + + +###################################################################### +# Embedding +# ~~~~~~~~~ +# The next step is to embed the tokens into the hidden states. We use the `embed` function +# compiled in the Relax IRModule to embed the tokens into the hidden states. + +nd_view_func = tvm.get_global_func("vm.builtin.reshape") + + +def embed(tokens, params): + _embed = vm["embed"](tokens, params) + # Reshape hidden from [seq_len, hidden_size] to [1, seq_len, hidden_size] + _embed = nd_view_func(_embed, ShapeTuple([1, _embed.shape[0], _embed.shape[1]])) + return _embed + + +###################################################################### +# Prefill +# ~~~~~~~ +# Before running the forward pass, we first get some help functions for preparation. + +add_sequence_func = tvm.get_global_func("vm.builtin.kv_state_add_sequence") +begin_forward_func = tvm.get_global_func("vm.builtin.kv_state_begin_forward") +end_forward_func = tvm.get_global_func("vm.builtin.kv_state_end_forward") + +###################################################################### +# As we are creating a new sequence, we need to call `add_sequence_func` to initialize +# the request. Additionally, we need to call `begin_forward_func` to start the forward pass, +# and `end_forward_func` to end the forward pass. + +if not IS_IN_CI: + seq_id = 0 + add_sequence_func(kv_cache, seq_id) + hidden_states = embed(tokens, params) + begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([input_len])) + logits, kv_cache = vm["prefill"](hidden_states, kv_cache, params) + end_forward_func(kv_cache) + +###################################################################### +# Now we have the output logits from the prefill step. The logits are used to generate the token +# via sampling. Let's sample the token from the logits. +# +# In this tutorial, we simplify the sampling process and pick the token with the highest +# probability. In practice, we should sample the token based on the probability distribution. +# Also, to make the tutorial concise, we execute the sample process on CPU. + + +def sample_token(logits): + logits_np = logits.numpy() + return np.argmax(logits_np) + + +if not IS_IN_CI: + last_token = sample_token(logits) + output_tokens = [last_token] + + +###################################################################### +# Decode +# ~~~~~~ +# After the prefill step, we can start the decode step. The decode step is used to generate the +# token until the end token is generated. We use the `decode` function compiled in the Relax +# IRModule to generate the token. + +if not IS_IN_CI: + print("The generated token:") + + while last_token != tokenizer.eos_token_id: + tokens = tvm.nd.array(np.array([last_token]).astype("int32"), device=dev) + hidden_states = embed(tokens, params) + begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([1])) + logits, kv_cache = vm["decode"](hidden_states, kv_cache, params) + + end_forward_func(kv_cache) + last_token = sample_token(logits) + output_tokens.append(last_token) + + print(tokenizer.decode(output_tokens)) diff --git a/docs/index.rst b/docs/index.rst index fdfaa56f7454..5d5d07640134 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -41,7 +41,11 @@ driving its costs down. :maxdepth: 1 :caption: How To - how_to/index + how_to/tutorials/e2e_opt_model + how_to/tutorials/customize_opt + how_to/tutorials/optimize_llm + how_to/tutorials/cross_compilation_and_rpc + how_to/dev/index .. toctree:: :maxdepth: 1 diff --git a/docs/legacy_redirect.py b/docs/legacy_redirect.py index 5e4bdd7430d6..502c7dd0b5bf 100644 --- a/docs/legacy_redirect.py +++ b/docs/legacy_redirect.py @@ -206,7 +206,6 @@ "../../how_to/work_with_relay/using_external_lib.html", ], ["tutorials/language/extern_op.html", "../../how_to/work_with_schedules/extern_op.html"], - ["tutorials/language/intrin_math.html", "../../how_to/work_with_schedules/intrin_math.html"], ["tutorials/language/reduction.html", "../../how_to/work_with_schedules/reduction.html"], ["tutorials/language/scan.html", "../../how_to/work_with_schedules/scan.html"], [ diff --git a/gallery/how_to/work_with_schedules/intrin_math.py b/gallery/how_to/work_with_schedules/intrin_math.py deleted file mode 100644 index 5a35ae1cbd8e..000000000000 --- a/gallery/how_to/work_with_schedules/intrin_math.py +++ /dev/null @@ -1,173 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Intrinsics and Math Functions -============================= -**Author**: `Tianqi Chen `_ - -While TVM supports basic arithmetic operations. In many cases -usually we will need more complicated builtin functions. -For example :code:`exp` to take the exponential of the function. - -These functions are target system dependent and may have different -names of different target platforms. In this tutorial, we will learn -how we can invoke these target specific functions, and how we can unify -the interface via TVM's intrinsic API. -""" -from __future__ import absolute_import, print_function - -import numpy as np - -import tvm -from tvm import te -from tvm.ir import register_op_attr, register_intrin_lowering - -###################################################################### -# Direct Declare Extern Math Call -# ------------------------------- -# The most straight-forward way to call target specific function is via -# extern function call construct in tvm. -# In the following example, we use :any:`tvm.tir.call_pure_extern` to call -# :code:`__expf` function, which is only available under CUDA. -# -n = te.var("n") -A = te.placeholder((n,), name="A") -B = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("float32", "__expf", A[i]), name="B") -s = te.create_schedule(B.op) -num_thread = 64 -bx, tx = s[B].split(B.op.axis[0], factor=num_thread) -s[B].bind(bx, te.thread_axis("blockIdx.x")) -s[B].bind(tx, te.thread_axis("threadIdx.x")) -f = tvm.build(s, [A, B], "cuda", name="myexp") -print(f.imported_modules[0].get_source()) - -###################################################################### -# Unified Intrinsic Call -# ---------------------- -# The above code verifies that direct external call can be used to -# call into device specific functions. -# However, the above way only works for CUDA target with float type. -# Ideally, we want to write same code for any device and any data type. -# -# TVM intrinsic provides the user a mechanism to achieve this, and this -# is the recommended way to solve the problem. -# The following code use te.exp instead, which create an intrinsic call -# :py::func:`tvm.te.exp` to do the exponential. -# -n = te.var("n") -A = te.placeholder((n,), name="A") -B = te.compute(A.shape, lambda i: te.exp(A[i]), name="B") -s = te.create_schedule(B.op) -num_thread = 64 -bx, tx = s[B].split(B.op.axis[0], factor=num_thread) -s[B].bind(bx, te.thread_axis("blockIdx.x")) -s[B].bind(tx, te.thread_axis("threadIdx.x")) -fcuda = tvm.build(s, [A, B], "cuda", name="myexp") -print(fcuda.imported_modules[0].get_source()) -###################################################################### -# We can find that the code works for both CUDA and opencl. -# The same te.exp can also be used for float64 data types. -# -fopencl = tvm.build(s, [A, B], "opencl", name="myexp") -print(fopencl.imported_modules[0].get_source()) - -###################################################################### -# Intrinsic Lowering Rule -# ----------------------- -# When :py:func:`tvm.te.exp` is called, TVM creates an intrinsic Call Expr. -# TVM uses transformation rules to transform the intrinsic -# call to device specific extern calls. -# -# TVM also allows user to customize the rules during runtime. -# The following example customizes CUDA lowering rule for :code:`exp`. -# - - -def my_cuda_math_rule(op): - """Customized CUDA intrinsic lowering rule""" - assert isinstance(op, tvm.tir.Call) - name = op.op.name - assert name.startswith("tir.") - dispatch_name = name[4:] - if op.dtype == "float32": - # call float function - return tvm.tir.call_pure_extern("float32", "%sf" % dispatch_name, op.args[0]) - elif op.dtype == "float64": - # call double function - return tvm.tir.call_pure_extern("float32", dispatch_name, op.args[0]) - else: - # cannot do translation, return self. - return op - - -register_intrin_lowering("tir.exp", target="cuda", f=my_cuda_math_rule, level=99) -###################################################################### -# Register the rule to TVM with override option to override existing rule. -# Notice the difference between the printed code from previous one: -# our new rule uses math function :code:`expf` instead of -# fast math version :code:`__expf`. -# -fcuda = tvm.build(s, [A, B], "cuda", name="myexp") -print(fcuda.imported_modules[0].get_source()) - -###################################################################### -# Add Your Own Intrinsic -# ---------------------- -# If there is an intrinsic that is not provided by TVM. -# User can easily add new intrinsic by using the intrinsic rule system. -# The following example add an intrinsic :code:`mylog` to the system. -# - - -def mylog(x): - """customized log intrinsic function""" - return tvm.tir.call_intrin(x.dtype, "tir.mylog", x) - - -def my_cuda_mylog_rule(op): - """CUDA lowering rule for log""" - if op.dtype == "float32": - return tvm.tir.call_pure_extern("float32", "logf", op.args[0]) - elif op.dtype == "float64": - return tvm.tir.call_pure_extern("float64", "log", op.args[0]) - else: - return op - - -# new op registration is triggered by registering an attribute of the op -register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure) -register_intrin_lowering("tir.mylog", target="cuda", f=my_cuda_mylog_rule, level=99) - -n = te.var("n") -A = te.placeholder((n,), name="A") -B = te.compute(A.shape, lambda i: mylog(A[i]), name="B") -s = te.create_schedule(B.op) -num_thread = 64 -bx, tx = s[B].split(B.op.axis[0], factor=num_thread) -s[B].bind(bx, te.thread_axis("blockIdx.x")) -s[B].bind(tx, te.thread_axis("threadIdx.x")) -fcuda = tvm.build(s, [A, B], "cuda", name="mylog") -print(fcuda.imported_modules[0].get_source()) - -###################################################################### -# Summary -# ------- -# - TVM can call extern target dependent math function. -# - Use intrinsic to defined a unified interface for the functions. -# - For more intrinsics available in tvm, take a look at :any:`tvm.tir` -# - You can customize the intrinsic behavior by defining your own rules. -# From 89a220822d7b980c8d944acaafcaa7ec189b9453 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 4 Sep 2024 21:00:38 +0800 Subject: [PATCH 105/202] [Doc] Deep Dive TensorIR (#17328) This PR adds a new section in the documentation to introduce the TensorIR abstraction, its learning resources, and tutorials. --- docs/conf.py | 2 + docs/deep_dive/tensor_ir/abstraction.rst | 73 +++++ docs/deep_dive/tensor_ir/index.rst | 31 ++ docs/deep_dive/tensor_ir/learning.rst | 253 ++++++++++++++++ docs/deep_dive/tensor_ir/tutorials/README.txt | 2 + .../deep_dive/tensor_ir/tutorials/creation.py | 285 ++++++++++++++++++ .../tensor_ir/tutorials/transformation.py | 173 +++++++++++ docs/index.rst | 9 + 8 files changed, 828 insertions(+) create mode 100644 docs/deep_dive/tensor_ir/abstraction.rst create mode 100644 docs/deep_dive/tensor_ir/index.rst create mode 100644 docs/deep_dive/tensor_ir/learning.rst create mode 100644 docs/deep_dive/tensor_ir/tutorials/README.txt create mode 100644 docs/deep_dive/tensor_ir/tutorials/creation.py create mode 100644 docs/deep_dive/tensor_ir/tutorials/transformation.py diff --git a/docs/conf.py b/docs/conf.py index 1ffc4dcafdb2..8c71f5eb1d55 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -424,6 +424,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): # New tutorial structure under docs folder tvm_path.joinpath("docs", "get_started", "tutorials"), tvm_path.joinpath("docs", "how_to", "tutorials"), + tvm_path.joinpath("docs", "deep_dive", "tensor_ir", "tutorials"), ] gallery_dirs = [ @@ -442,6 +443,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): # New tutorial structure under docs folder "get_started/tutorials/", "how_to/tutorials/", + "deep_dive/tensor_ir/tutorials/", ] diff --git a/docs/deep_dive/tensor_ir/abstraction.rst b/docs/deep_dive/tensor_ir/abstraction.rst new file mode 100644 index 000000000000..fc11d7f39156 --- /dev/null +++ b/docs/deep_dive/tensor_ir/abstraction.rst @@ -0,0 +1,73 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _tir-abstraction: + +Tensor Program Abstraction +-------------------------- +Before we dive into the details of TensorIR, let's first introduce what is a primitive tensor +function. Primitive tensor functions are functions that correspond to a single "unit" of +computational operation. For example, a convolution operation can be a primitive tensor function, +and a fused convolution + relu operation can also be a primitive tensor function. +Usually, a typical abstraction for primitive tensor function implementation contains the following +elements: multi-dimensional buffers, loop nests that drive the tensor computations, and finally, +the compute statements themselves. + +.. code:: python + + from tvm.script import tir as T + + @T.prim_func + def main( + A: T.Buffer((128,), "float32"), + B: T.Buffer((128,), "float32"), + C: T.Buffer((128,), "float32"), + ) -> None: + for i in range(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + C[vi] = A[vi] + B[vi] + +Key Elements of Tensor Programs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The demonstrated primitive tensor function calculates the element-wise sum of two vectors. +The function: + +- Accepts three **multi-dimensional buffers** as parameters, and generates one **multi-dimensional + buffer** as output. +- Incorporates a solitary **loop nest** ``i`` that facilitates the computation. +- Features a singular **compute statement** that calculates the element-wise sum of the two + vectors. + +Extra Structure in TensorIR +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Crucially, we are unable to execute arbitrary transformations on the program, as certain +computations rely on the loop's sequence. Fortunately, the majority of primitive tensor +functions we focus on possess favorable properties, such as independence among loop iterations. +For instance, the aforementioned program includes block and iteration annotations: + +- The **block annotation** ``with T.block("C")`` signifies that the block is the fundamental + computation unit designated for scheduling. A block may encompass a single computation + statement, multiple computation statements with loops, or opaque intrinsics such as Tensor + Core instructions. +- The **iteration annotation** ``T.axis.spatial``, indicating that variable ``vi`` is mapped + to ``i``, and all iterations are independent. + +While this information isn't crucial for *executing* the specific program, it proves useful when +transforming the program. Consequently, we can confidently parallelize or reorder loops associated +with ``vi``, provided we traverse all the index elements from 0 to 128. diff --git a/docs/deep_dive/tensor_ir/index.rst b/docs/deep_dive/tensor_ir/index.rst new file mode 100644 index 000000000000..432d47116a3c --- /dev/null +++ b/docs/deep_dive/tensor_ir/index.rst @@ -0,0 +1,31 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _tensor-ir: + +TensorIR +======== +TensorIR is one of the core abstraction in Apache TVM Unity stack, which is used to +represent and optimize the primitive tensor functions. + +.. toctree:: + :maxdepth: 2 + + abstraction + learning + tutorials/creation + tutorials/transformation diff --git a/docs/deep_dive/tensor_ir/learning.rst b/docs/deep_dive/tensor_ir/learning.rst new file mode 100644 index 000000000000..7ca0a1514fbd --- /dev/null +++ b/docs/deep_dive/tensor_ir/learning.rst @@ -0,0 +1,253 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _tir-learning: + +Understand TensorIR Abstraction +=============================== +TensorIR is the tensor program abstraction in Apache TVM, which is one of the standard +machine learning compilation frameworks. The principal objective of tensor program abstraction +is to depict loops and associated hardware acceleration options, including threading, the +application of specialized hardware instructions, and memory access. + +To help our explanations, let us use the following sequence of tensor computations as +a motivating example. Specifically, for two :math:`128 \times 128` matrices ``A`` and ``B``, let us perform the +following two steps of tensor computations. + +.. math:: + + Y_{i, j} &= \sum_k A_{i, k} \times B_{k, j} \\ + C_{i, j} &= \mathbb{relu}(Y_{i, j}) = \mathbb{max}(Y_{i, j}, 0) + + +The above computations resemble a typical primitive tensor function commonly seen in neural networks, +a linear layer with relu activation. We use TensorIR to depict the above computations as follows. + +Before we invoke TensorIR, let's use native Python codes with NumPy to show the computation: + +.. code:: python + + def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray): + Y = np.empty((128, 128), dtype="float32") + for i in range(128): + for j in range(128): + for k in range(128): + if k == 0: + Y[i, j] = 0 + Y[i, j] = Y[i, j] + A[i, k] * B[k, j] + for i in range(128): + for j in range(128): + C[i, j] = max(Y[i, j], 0) + +With the low-level NumPy example in mind, now we are ready to introduce TensorIR. The code block +below shows a TensorIR implementation of ``mm_relu``. The particular code is implemented in a +language called TVMScript, which is a domain-specific dialect embedded in python AST. + +.. code:: python + + @tvm.script.ir_module + class MyModule: + @T.prim_func + def mm_relu(A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32")): + Y = T.alloc_buffer((128, 128), dtype="float32") + for i, j, k in T.grid(128, 128, 128): + with T.block("Y"): + vi = T.axis.spatial(128, i) + vj = T.axis.spatial(128, j) + vk = T.axis.reduce(128, k) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi = T.axis.spatial(128, i) + vj = T.axis.spatial(128, j) + C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) + + +Next, let's invest the elements in the above TensorIR program. + +Function Parameters and Buffers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +**The function parameters correspond to the same set of parameters on the numpy function.** + +.. code:: python + + # TensorIR + def mm_relu(A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"]): + ... + # NumPy + def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray): + ... + +Here ``A``, ``B``, and ``C`` takes a type named ``T.Buffer``, which with shape +argument ``(128, 128)`` and data type ``float32``. This additional information +helps possible MLC process to generate code that specializes in the shape and data +type. + +**Similarly, TensorIR also uses a buffer type in intermediate result allocation.** + +.. code:: python + + # TensorIR + Y = T.alloc_buffer((128, 128), dtype="float32") + # NumPy + Y = np.empty((128, 128), dtype="float32") + +Loop Iterations +~~~~~~~~~~~~~~~ +**There are also direct correspondence of loop iterations.** + +``T.grid`` is a syntactic sugar in TensorIR for us to write multiple nested iterators. + +.. code:: python + + # TensorIR with `T.grid` + for i, j, k in T.grid(128, 128, 128): + ... + # TensorIR with `range` + for i in range(128): + for j in range(128): + for k in range(128): + ... + # NumPy + for i in range(128): + for j in range(128): + for k in range(128): + ... + +Computational Block +~~~~~~~~~~~~~~~~~~~ +A significant distinction lies in computational statements: +**TensorIR incorporates an additional construct termed** ``T.block``. + +.. code:: python + + # TensorIR + with T.block("Y"): + vi = T.axis.spatial(128, i) + vj = T.axis.spatial(128, j) + vk = T.axis.reduce(128, k) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + # NumPy + vi, vj, vk = i, j, k + if vk == 0: + Y[vi, vj] = 0 + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + +A **block** represents a fundamental computation unit within TensorIR. Importantly, +a block encompasses more information than standard NumPy code. It comprises a set of block axes +``(vi, vj, vk)`` and the computations delineated around them. + +.. code:: python + + vi = T.axis.spatial(128, i) + vj = T.axis.spatial(128, j) + vk = T.axis.reduce(128, k) + +The above three lines declare the **key properties** about block axes in the following syntax. + +.. code:: python + + [block_axis] = T.axis.[axis_type]([axis_range], [mapped_value]) + +These three lines convey the following details: + +- They specify the binding of ``vi``, ``vj``, ``vk`` (in this instance, to ``i``, ``j``, ``k``). +- They declare the original range intended for ``vi``, ``vj``, ``vk`` + (the 128 in ``T.axis.spatial(128, i)``). +- They announce the properties of the iterators (spatial, reduce). + +Block Axis Properties +~~~~~~~~~~~~~~~~~~~~~ +Let's delve deeper into the properties of the block axis. These properties signify the axis's +relationship to the computation in progress. The block comprises three axes ``vi``, ``vj``, and +``vk``, meanwhile the block reads the buffer ``A[vi, vk]``, ``B[vk, vj]`` and writs the buffer +``Y[vi, vj]``. Strictly speaking, the block performs (reduction) updates to Y, which we label +as write for the time being, as we don't require the value of Y from another block. + +Significantly, for a fixed value of ``vi`` and ``vj``, the computation block yields a point +value at a spatial location of ``Y`` (``Y[vi, vj]``) that is independent of other locations in ``Y`` +(with different ``vi``, ``vj`` values). We can refer to ``vi``, ``vj`` as **spatial axes** since +they directly correspond to the start of a spatial region of buffers that the block writes to. +The axes involved in reduction (``vk``) are designated as **reduce axes**. + +Why Extra Information in Block +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +One crucial observation is that the additional information (block axis range and their properties) +makes the block to be **self-contained** when it comes to the iterations that it is supposed to +carry out independent from the external loop-nest ``i, j, k``. + +The block axis information also provides additional properties that help us to validate the correctness of the +external loops that are used to carry out the computation. For example, the above code block will result in an +error because the loop expects an iterator of size 128, but we only bound it to a for loop of size 127. + +.. code:: python + + # wrong program due to loop and block iteration mismatch + for i in range(127): + with T.block("C"): + vi = T.axis.spatial(128, i) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + error here due to iterator size mismatch + ... + +Sugars for Block Axes Binding +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +In situations where each of the block axes is directly mapped to an outer loop iterator, +we can use ``T.axis.remap`` to declare the block axis in a single line. + +.. code:: python + + # SSR means the properties of each axes are "spatial", "spatial", "reduce" + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + +which is equivalent to + +.. code:: python + + vi = T.axis.spatial(range_of_i, i) + vj = T.axis.spatial(range_of_j, j) + vk = T.axis.reduce (range_of_k, k) + +So we can also write the programs as follows. + +.. code:: python + + @tvm.script.ir_module + class MyModuleWithAxisRemapSugar: + @T.prim_func + def mm_relu(A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32")): + Y = T.alloc_buffer((128, 128), dtype="float32") + for i, j, k in T.grid(128, 128, 128): + with T.block("Y"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) diff --git a/docs/deep_dive/tensor_ir/tutorials/README.txt b/docs/deep_dive/tensor_ir/tutorials/README.txt new file mode 100644 index 000000000000..bbbd7d3e5a20 --- /dev/null +++ b/docs/deep_dive/tensor_ir/tutorials/README.txt @@ -0,0 +1,2 @@ +Deep Dive: TensorIR +------------------- diff --git a/docs/deep_dive/tensor_ir/tutorials/creation.py b/docs/deep_dive/tensor_ir/tutorials/creation.py new file mode 100644 index 000000000000..51481fb2e325 --- /dev/null +++ b/docs/deep_dive/tensor_ir/tutorials/creation.py @@ -0,0 +1,285 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _tir-creation: + +TensorIR Creation +----------------- +In this section, we will introduce the methods to write a TensorIR function +in Apache TVM Unity. This tutorial presumes familiarity with the fundamental concepts of TensorIR. +If not already acquainted, please refer to :ref:`tir-learning` initially. + +.. note:: + + This tutorial concentrates on the construction of **standalone** TensorIR functions. The + techniques presented here are not requisite for end users to compile Relax models. + +""" + +###################################################################### +# Create TensorIR using TVMScript +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# The most straightforward way to create a TensorIR function via TVMScript. +# TVMScript is a TVM Python dialect that represents TensorIR in TVM. +# +# .. important:: +# +# While TVMScript employs Python syntax and AST, ensuring full compatibility +# with Python tools like auto-completion and linting, it is not a native Python +# language and cannot be executed by a Python interpreter. +# +# More precisely, the decorator **@tvm.script** extracts the Python AST from +# the decorated function, subsequently parsing it into TensorIR. +# +# Standard Format +# *************** +# Let's take an example of ``mm_relu`` from :ref:`tir-learning`. Here is the complete +# format of the ir_module and in TVMScript: + + +import numpy as np +import tvm +from tvm.script import ir as I +from tvm.script import tir as T + + +@I.ir_module +class MyModule: + @T.prim_func + def mm_relu( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + Y = T.alloc_buffer((128, 128), dtype="float32") + for i in range(128): + for j in range(128): + for k in range(128): + with T.block("Y"): + vi = T.axis.spatial(128, i) + vj = T.axis.spatial(128, j) + vk = T.axis.reduce(128, k) + T.reads(A[vi, vk], B[vk, vj]) + T.writes(Y[vi, vj]) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i in range(128): + for j in range(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + vj = T.axis.spatial(128, j) + T.reads(Y[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) + + +###################################################################### +# Concise with Syntactic Sugar +# **************************** +# For ease of writing, we can employ the following syntactic sugar to +# streamline the code: +# +# - Utilize ``T.grid`` to condense nested loops; +# - Employ ``T.axis.remap`` to abbreviate block iterator annotations; +# - Exclude ``T.reads`` and ``T.writes`` for blocks whose content can +# be inferred from the block body; + + +@I.ir_module +class ConciseModule: + @T.prim_func + def mm_relu( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + Y = T.alloc_buffer((128, 128), dtype="float32") + for i, j, k in T.grid(128, 128, 128): + with T.block("Y"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) + + +###################################################################### +# We can use the following code to verify that the two modules are equivalent: + +print(tvm.ir.structural_equal(MyModule, ConciseModule)) + +###################################################################### +# Interactive with Python Variables +# ********************************* +# Despite TVMScript not being executed by a Python interpreter, limited +# interaction with Python is feasible. For instance, Python variables can +# be used to ascertain the shape and data type of a TensorIR. + +# Python variables +M = N = K = 128 +dtype = "float32" + + +# IRModule in TVMScript +@I.ir_module +class ConciseModuleFromPython: + @T.prim_func + def mm_relu( + A: T.Buffer((M, K), dtype), + B: T.Buffer((K, N), dtype), + C: T.Buffer((M, N), dtype), + ): + Y = T.alloc_buffer((M, N), dtype) + for i, j, k in T.grid(M, N, K): + with T.block("Y"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Y[vi, vj] = T.cast(T.float32(0), dtype) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(M, N): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype)) + + +###################################################################### +# Check the equivalence: + +print(tvm.ir.structural_equal(ConciseModule, ConciseModuleFromPython)) + + +###################################################################### +# TensorIR Function with Dynamic Shapes +# ************************************* +# Despite TVMScript not being executed by a Python interpreter, limited +# interaction with Python is feasible. For instance, Python variables can +# be used to ascertain the shape and data type of a TensorIR. + + +@I.ir_module +class DynamicShapeModule: + @T.prim_func + def mm_relu(a: T.handle, b: T.handle, c: T.handle): + # Dynamic shape definition + M, N, K = T.int32(), T.int32(), T.int32() + + # Bind the input buffers with the dynamic shapes + A = T.match_buffer(a, [M, K], dtype) + B = T.match_buffer(b, [K, N], dtype) + C = T.match_buffer(c, [M, N], dtype) + Y = T.alloc_buffer((M, N), dtype) + for i, j, k in T.grid(M, N, K): + with T.block("Y"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Y[vi, vj] = T.cast(T.float32(0), dtype) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(M, N): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype)) + + +###################################################################### +# Now let's check the runtime dynamic shape inference: + + +def evaluate_dynamic_shape(lib: tvm.runtime.Module, m: int, n: int, k: int): + A = tvm.nd.array(np.random.uniform(size=(m, k)).astype("float32")) + B = tvm.nd.array(np.random.uniform(size=(k, n)).astype("float32")) + C = tvm.nd.array(np.zeros((m, n), dtype="float32")) + lib(A, B, C) + return C.numpy() + + +# Compile lib only once +dyn_shape_lib = tvm.build(DynamicShapeModule, target="llvm") +# Able to handle different shapes +print(evaluate_dynamic_shape(dyn_shape_lib, m=4, n=4, k=4)) +print(evaluate_dynamic_shape(dyn_shape_lib, m=64, n=64, k=128)) + +###################################################################### +# Create TensorIR using Tensor Expression +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Often, the specifics of TensorIR are disregarded in favor of expressing the computation more +# succinctly, leading to the pragmatic generation of TensorIR. This is where Tensor Expression +# (TE) becomes relevant. +# +# Tensor Expression (TE) serves as a domain-specific language delineating a sequence of +# computations through an expression-like API. +# +# .. note:: +# +# Tensor Expression comprises two components within the TVM stack: the expression and the +# schedule. The expression is the domain-specific language embodying the computation pattern, +# precisely what we're addressing in this section. Conversely, the TE schedule is the legacy +# scheduling method, has been superseded by the TensorIR schedule in the TVM Unity stack. +# +# Create Static-Shape Functions +# ***************************** +# We use the same example of ``mm_relu`` from the last subsection to demonstrate the +# TE creation method. + +from tvm import te + +A = te.placeholder((128, 128), "float32", name="A") +B = te.placeholder((128, 128), "float32", name="B") +k = te.reduce_axis((0, 128), "k") +Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y") +C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C") + +###################################################################### +# Here ``te.compute`` takes the signature ``te.compute(output_shape, fcompute)``. +# And the fcompute function describes how we want to compute the value of each +# element ``Y[i, j]`` for a given index: +# +# .. code:: python +# +# lambda i, j: te.sum(A[i, k] * B[k, j], axis=k) +# +# The aforementioned lambda expression encapsulates the computation: +# :math:`Y_{i, j} = \sum_k A_{i, k} \times B_{k, j}`. Upon defining the computation, +# we can formulate a TensorIR function by incorporating the pertinent parameters of interest. +# In this specific instance, we aim to construct a function with two input parameters **A, B** +# and one output parameter **C**. + +te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"}) +TEModule = tvm.IRModule({"mm_relu": te_func}) +TEModule.show() + +###################################################################### +# Create Dynamic-Shape Functions +# ****************************** +# We can also create a dynamic-shape function using Tensor Expression. The only difference +# is that we need to specify the shape of the input tensors as symbolic variables. + +# Declare symbolic variables +M, N, K = te.var("m"), te.var("n"), te.var("k") +A = te.placeholder((M, N), "float32", name="A") +B = te.placeholder((K, N), "float32", name="B") +k = te.reduce_axis((0, K), "k") +Y = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y") +C = te.compute((M, N), lambda i, j: te.max(Y[i, j], 0), name="C") + +dyn_te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"}) +DynamicTEModule = tvm.IRModule({"mm_relu": dyn_te_func}) +DynamicTEModule.show() diff --git a/docs/deep_dive/tensor_ir/tutorials/transformation.py b/docs/deep_dive/tensor_ir/tutorials/transformation.py new file mode 100644 index 000000000000..1dcf8e7ab5c8 --- /dev/null +++ b/docs/deep_dive/tensor_ir/tutorials/transformation.py @@ -0,0 +1,173 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _tir-transform: + +Transformation +-------------- +In this section, we will get to the main ingredients of the compilation flows - +transformations of primitive tensor functions. +""" + +###################################################################### +# In the :ref:`previous section `, we have given an example of how to write +# ``mm_relu`` using TensorIR. In practice, there can be multiple ways to implement +# the same functionality, and each implementation can result in different performance. +# +# .. note:: +# This tutorial primarily illustrates the application of TensorIR Transformation, +# rather than delving into optimization techniques. +# +# First, let's take a look at the implementation of ``mm_relu`` in the previous section: + +import tvm +from tvm.script import ir as I +from tvm.script import tir as T + + +@I.ir_module +class MyModule: + @T.prim_func + def main( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + Y = T.alloc_buffer((128, 128)) + for i, j, k in T.grid(128, 128, 128): + with T.block("Y"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) + + +###################################################################### +# Before we transform the function, let's first evaluate the performance of the +# original implementation. + +import numpy as np + +a_np = np.random.uniform(size=(128, 128)).astype("float32") +b_np = np.random.uniform(size=(128, 128)).astype("float32") +c_np = a_np @ b_np + +a_nd = tvm.nd.array(a_np) +b_nd = tvm.nd.array(b_np) +c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32")) + + +def evaluate(mod: tvm.IRModule): + lib = tvm.build(mod, target="llvm") + # check correctness + lib(a_nd, b_nd, c_nd) + np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5) + # evaluate performance + f_timer = lib.time_evaluator("main", tvm.cpu()) + print(f_timer(a_nd, b_nd, c_nd)) + + +evaluate(MyModule) + +###################################################################### +# Initialization Schedule +# *********************** +# We initiate the process of code transformation by establishing a Schedule helper class, +# utilizing the provided **MyModule** as input. + +sch = tvm.tir.Schedule(MyModule) + +###################################################################### +# Loop Tiling +# *********** +# Subsequently, we execute the requisite operations to acquire a reference to +# block **Y** and its associated loops. + +block_Y = sch.get_block("Y") +i, j, k = sch.get_loops(block_Y) + +###################################################################### +# We now proceed to execute the transformations. The initial modification involves +# splitting loop ``j`` into two separate loops, with the inner loop possessing a +# length of 4. It is crucial to understand that the transformation process is procedural; +# thus, inadvertent execution of the block twice will yield an error stating the +# non-existence of variable ``j``. + +j0, j1 = sch.split(j, factors=[None, 8]) + +###################################################################### +# The outcome of the transformation can be examined, as it is retained within ``sch.mod``. + +sch.mod.show() + +###################################################################### +# Following the initial transformation phase, two supplementary loops, ``j_0`` and ``j_1``, +# have been generated with respective ranges of 32 and 4. The subsequent +# action involves reordering these two loops. + +sch.reorder(j0, k, j1) +sch.mod.show() +evaluate(sch.mod) + +###################################################################### +# Leverage Localities +# ******************* +# Subsequently, we will execute two additional transformation steps to achieve a different +# variant. First, we employ a primitive known as **reverse_compute_at** to relocate block +# **C** to an inner loop of **Y**. + +block_C = sch.get_block("C") +sch.reverse_compute_at(block_C, j0) +sch.mod.show() + +###################################################################### +# Rewrite Reduction +# ***************** +# Until now, the reduction initialization and update step have been maintained together +# within a single block body. This amalgamated form facilitates loop transformations, +# as the outer loops ``i``, ``j`` of initialization and updates generally need to remain +# synchronized. +# +# Following the loop transformations, we can segregate the initialization of Y's elements +# from the reduction update via the **decompose_reduction** primitive. + +sch.decompose_reduction(block_Y, k) +sch.mod.show() +evaluate(sch.mod) + +###################################################################### +# Trace the Transformation +# ************************ +# TensorIR schedule is a procedural language, and the transformation is executed in a +# step-by-step manner. We can trace the transformation by printing the schedule or the +# history of the schedule. +# +# We've already see the schedule by printing ``sch.mod``. We can also print the history +# of the schedule by ``sch.trace``. + +sch.trace.show() + +###################################################################### +# Alternatively, we can output the IRModule in conjunction with the historical trace. + +sch.show() diff --git a/docs/index.rst b/docs/index.rst index 5d5d07640134..2eec0cb99e97 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -47,6 +47,15 @@ driving its costs down. how_to/tutorials/cross_compilation_and_rpc how_to/dev/index +.. The Deep Dive content is comprehensive +.. we maintain a ``maxdepth`` of 2 to display more information on the main page. + +.. toctree:: + :maxdepth: 2 + :caption: Deep Dive + + deep_dive/tensor_ir/index + .. toctree:: :maxdepth: 1 :caption: API Reference From 56273574e6a250ddb3d2af15c8159e8913636b8c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 4 Sep 2024 10:51:15 -0500 Subject: [PATCH 106/202] [Relax] Allow dynamic shape argument to R.reshape (#17218) Prior to this commit, the `shape` argument to `R.reshape` was required to either be an in-line `relax::ShapeExpr`, or a variable that had been bound to a `relax::ShapeExpr` within the current function. As a result, shapes that were provided as function arguments or that were produced by another operation (e.g. `R.tensor_to_shape`) would unnecessarily trigger an error. This commit updates the `VMBuiltinLower` pass to instead check that the argument has `relax::ShapeStructInfo`. Closes https://github.com/apache/tvm/issues/17217 --- src/relax/backend/vm/lower_runtime_builtin.cc | 36 +++++----- tests/python/relax/test_vm_builtin_lower.py | 65 +++++++++++++++++++ 2 files changed, 85 insertions(+), 16 deletions(-) diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index a3867ae92448..4757561b549b 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -49,6 +49,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return Reshape(call); } else if (call->op == shape_of_op_) { return ShapeOf(call); + } else if (call->op == tensor_to_shape_op_) { + return TensorToShape(call); } else if (call->op == to_vdevice_op_) { return ToDevice(call); } else if (call->op == make_closure_op_) { @@ -112,22 +114,15 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args.size() == 2); ICHECK(call_node->struct_info_.defined()); auto arg = call_node->args[1]; - CHECK(arg->IsInstance() || arg->IsInstance()) - << "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr or VarNode bound " - "to a ShapeExpr"; - - if (arg->IsInstance()) { - return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); - } else { - // Handling the case when arg is VarNode - Optional _bound_val = LookupBinding(Downcast(arg)); - ICHECK(_bound_val.defined()); - Expr bound_val = _bound_val.value(); - CHECK(bound_val->IsInstance()) - << "VMBuiltinLower expects bound value to be a ShapeExpr"; - return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(), - {GetStructInfo(call_node)}); - } + + CHECK(arg->struct_info_->IsInstance()) + << "TypeError: " + << "VMBuiltinLower expects the shape arg of R.reshape " + << "to be a ShapeExpr or VarNode bound to a ShapeExpr. " + << "However, in expression " << call_node << ", the shape argument " << arg + << " has struct info " << arg->struct_info_; + + return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } Expr ShapeOf(const Call& call_node) { @@ -136,6 +131,13 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } + Expr TensorToShape(const Call& call_node) { + ICHECK(call_node->args.size() == 1); + ICHECK(call_node->struct_info_.defined()); + + return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + Expr ToDevice(const Call& call_node) { // TODO(yongwww): replace ToVDeviceAttrs with related Expr ICHECK(call_node->args.size() == 1); @@ -194,6 +196,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); + const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); @@ -211,6 +214,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"}; const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"}; + const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"}; const ExternFunc builtin_to_device_{"vm.builtin.to_device"}; const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py index 984f9f958ca2..daa59793cc47 100644 --- a/tests/python/relax/test_vm_builtin_lower.py +++ b/tests/python/relax/test_vm_builtin_lower.py @@ -82,5 +82,70 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: relax.transform.LowerRuntimeBuiltin()(Before) +def test_vm_reshape_may_be_var(): + """R.reshape does not require an in-line R.shape""" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32"), shape: R.Shape): + R.func_attr({"relax.force_pure": True}) + reshape = R.reshape(A, shape) + return reshape + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32"), shape: R.Shape): + R.func_attr({"relax.force_pure": True}) + reshape = R.call_packed( + "vm.builtin.reshape", + A, + shape, + sinfo_args=R.Tensor(shape, dtype="float32"), + ) + return reshape + + After = relax.transform.VMBuiltinLower()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + +def test_vm_reshape_using_tensor_to_shape(): + """Shape argument of R.reshape may come from tensor_to_shape""" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")): + R.func_attr({"relax.force_pure": True}) + shape = R.tensor_to_shape(shape_tensor) + reshape = R.reshape(A, shape) + return reshape + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")): + R.func_attr({"relax.force_pure": True}) + + shape = R.call_packed( + "vm.builtin.tensor_to_shape", + shape_tensor, + sinfo_args=R.Shape(ndim=2), + ) + reshape = R.call_packed( + "vm.builtin.reshape", + A, + shape, + sinfo_args=R.Tensor(shape, dtype="float32"), + ) + return reshape + + After = relax.transform.VMBuiltinLower()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() From e19541d1e224110399cc81d1cfeecec365020e69 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 5 Sep 2024 03:25:46 +0900 Subject: [PATCH 107/202] [Relax][PyTorch][Bugfix] Update `layer_norm` converter to support `immutable_list` for `normalized_shape` (#17330) handle when the 2nd arg is a type of `immutable_list` --- python/tvm/relax/frontend/torch/fx_translator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 245bb4cffb57..49ff6c6b6d51 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1227,6 +1227,7 @@ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: def _layer_norm(self, node: fx.node.Node) -> relax.Var: import torch # type: ignore + from torch.fx.immutable_collections import immutable_list import numpy as np # type: ignore x = self.env[node.args[0]] @@ -1235,8 +1236,8 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var: if node.target not in self.named_modules: # static or symbolic arg = node.args[1] - if isinstance(arg, tuple): - value = arg + if isinstance(arg, (immutable_list, tuple)): + value = tuple(arg) else: try: value = self.env[arg] From 19b66bfed2f255401b235c8d08a1381322fab315 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 5 Sep 2024 03:26:08 +0900 Subject: [PATCH 108/202] [Relax][PyTorch] Add support for torchvision.ops.stochastic_depth (#17300) * add a test for stochastic_depth * add support for torchvision.ops.stochastic_depth --- .../tvm/relax/frontend/torch/fx_translator.py | 1 + tests/python/relax/test_frontend_from_fx.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 49ff6c6b6d51..21a0b2d5642a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1672,6 +1672,7 @@ def create_convert_map(self): "softmax": self._softmax, "log_softmax": self._log_softmax, "dropout": lambda node: self.env[node.args[0]], + "stochastic_depth": lambda node: self.env[node.args[0]], "clamp": self._clamp, "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), "leaky_relu": self._leakyrelu, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index e191775a63b2..35a9bc71bf98 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from torch import fx from torch.nn import Module +import torchvision import tvm from tvm import relax @@ -1212,6 +1213,37 @@ def main( verify_model(Dropout2(), input_info, {}, expected1) +def test_stochastic_depth(): + input_info = [([1, 3, 10, 10], "float32")] + + class StochasticDepth1(Module): + def __init__(self): + super().__init__() + self.stochastic_depth = torchvision.ops.StochasticDepth(0.5, mode="row") + + def forward(self, x): + return self.stochastic_depth(x) + + class StochasticDepth2(Module): + def forward(self, x): + return torchvision.ops.stochastic_depth(x, 0.5, mode="row", training=False) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = input_1 + R.output(gv) + return gv + + verify_model(StochasticDepth1(), input_info, {}, expected1) + verify_model(StochasticDepth2(), input_info, {}, expected1) + + def test_layernorm(): input_info = [([1, 3, 10, 10], "float32")] From 73b138b1924cd1a6c5877430f98ea39697c6654a Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 5 Sep 2024 06:08:10 +0900 Subject: [PATCH 109/202] [Rust] Remove mxnet dependency and re-enable rust example (#17293) * use torchvision's resnet18 instead of mxnet * re-enable rust example * update readme --- rust/tvm/README.md | 2 +- rust/tvm/examples/resnet/README.md | 2 +- rust/tvm/examples/resnet/build.rs | 6 ----- rust/tvm/examples/resnet/src/build_resnet.py | 28 ++++++++++---------- rust/tvm/examples/resnet/src/main.rs | 5 ---- 5 files changed, 16 insertions(+), 27 deletions(-) diff --git a/rust/tvm/README.md b/rust/tvm/README.md index b1bb4687679e..3455975ad81d 100644 --- a/rust/tvm/README.md +++ b/rust/tvm/README.md @@ -26,7 +26,7 @@ You can find the API Documentation [here](https://tvm.apache.org/docs/api/rust/t The goal of this crate is to provide bindings to both the TVM compiler and runtime APIs. First train your **Deep Learning** model using any major framework such as -[PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.apache.org/) or [TensorFlow](https://www.tensorflow.org/). +[PyTorch](https://pytorch.org/) or [TensorFlow](https://www.tensorflow.org/). Then use **TVM** to build and deploy optimized model artifacts on a supported devices such as CPU, GPU, OpenCL and specialized accelerators. The Rust bindings are composed of a few crates: diff --git a/rust/tvm/examples/resnet/README.md b/rust/tvm/examples/resnet/README.md index d6e32f7fa768..ad76ac0048a0 100644 --- a/rust/tvm/examples/resnet/README.md +++ b/rust/tvm/examples/resnet/README.md @@ -21,7 +21,7 @@ This end-to-end example shows how to: * build `Resnet 18` with `tvm` from Python * use the provided Rust frontend API to test for an input image -To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +To run the example with pretrained resnet weights, first `tvm` and `torchvision` must be installed for the python build. To install torchvision for cpu, run `pip install torch torchvision` and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html). * **Build the example**: `cargo build diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index 45e4d6d658d5..9e3a76433ffc 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -21,10 +21,6 @@ use anyhow::{Context, Result}; use std::{io::Write, path::Path, process::Command}; fn main() -> Result<()> { - // Currently disabled, as it depends on the no-longer-supported - // mxnet repo to download resnet. - - /* let out_dir = std::env::var("CARGO_MANIFEST_DIR")?; let python_script = concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"); let synset_txt = concat!(env!("CARGO_MANIFEST_DIR"), "/synset.txt"); @@ -57,7 +53,5 @@ fn main() -> Result<()> { ); println!("cargo:rustc-link-search=native={}", out_dir); - */ - Ok(()) } diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py index df02dd78f57c..4e8ae01c413b 100644 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -17,22 +17,18 @@ # under the License. import argparse -import csv import logging -from os import path as osp -import sys import shutil +from os import path as osp import numpy as np - +import torch +import torchvision import tvm -from tvm import te -from tvm import relay, runtime -from tvm.relay import testing -from tvm.contrib import graph_executor, cc from PIL import Image +from tvm import relay, runtime +from tvm.contrib import cc, graph_executor from tvm.contrib.download import download_testdata -from mxnet.gluon.model_zoo.vision import get_model logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -64,11 +60,16 @@ def build(target_dir): """Compiles resnet18 with TVM""" - # Download the pretrained model in MxNet's format. - block = get_model("resnet18_v1", pretrained=True) + # Download the pretrained model from Torchvision. + weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + torch_model = torchvision.models.resnet18(weights=weights).eval() + + input_shape = [1, 3, 224, 224] + input_data = torch.randn(input_shape) + scripted_model = torch.jit.trace(torch_model, input_data) + input_infos = [("data", input_data.shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, input_infos) - shape_dict = {"data": (1, 3, 224, 224)} - mod, params = relay.frontend.from_mxnet(block, shape_dict) # Add softmax to do classification in last layer. func = mod["main"] func = relay.Function( @@ -93,7 +94,6 @@ def build(target_dir): def download_img_labels(): """Download an image and imagenet1k class labels for test""" - from mxnet.gluon.utils import download synset_url = "".join( [ diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 0ea8c4cf8bb5..c22d55f2e4da 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -31,10 +31,6 @@ use tvm_rt::graph_rt::GraphRt; use tvm_rt::*; fn main() -> anyhow::Result<()> { - // Currently disabled, as it depends on the no-longer-supported - // mxnet repo to download resnet. - - /* let dev = Device::cpu(0); println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")); @@ -138,7 +134,6 @@ fn main() -> anyhow::Result<()> { "input image belongs to the class `{}` with probability {}", label, max_prob ); - */ Ok(()) } From e65aab6a4f55f4b405ef2713f842d6a3b761151b Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 5 Sep 2024 22:30:12 +0900 Subject: [PATCH 110/202] [Relax][PyTorch][Fix] use`_convert_torch_tensor_to_relax()` where possible (#17335) * use `_convert_torch_tensor_to_relax` where possible * add type annotation --- python/tvm/relax/frontend/torch/fx_translator.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 21a0b2d5642a..6e60c3bb6fc4 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,7 +62,7 @@ def _fetch_attr(self, model, target: str): return attr_itr @staticmethod - def _convert_data_type(input_type, env: Optional[Dict] = None): + def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None): """converts the PyTorch scalar type input_type to a TVM dtype.""" import torch # type: ignore @@ -1206,9 +1206,8 @@ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: module = self.named_modules[node.target] weight = self.params[module.weight] bias = self.params[module.bias] - dtype = TorchFXImporter._convert_data_type(str(module.running_mean.dtype)) - running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype) - running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype) + running_mean = self._convert_torch_tensor_to_relax(module.running_mean) + running_var = self._convert_torch_tensor_to_relax(module.running_var) eps = module.eps res_tuple = self.block_builder.emit( @@ -1769,7 +1768,7 @@ def from_fx( dtype = self._convert_data_type(str(param.data.dtype)) if dtype in ("float32", "float16"): if not keep_params_as_input: - self.params[param] = relax.const(param.data.cpu().numpy(), dtype) + self.params[param] = self._convert_torch_tensor_to_relax(param) else: raise ValueError("Unsupported data type for model parameters: %s" % dtype) # Translate the model. From 823763db5b35aec04fb021b47d3f8b06db08e0b0 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 5 Sep 2024 23:01:09 +0900 Subject: [PATCH 111/202] [Apps] Remove mxnet dependency from /apps/ios_rpc (#17299) use torchvision's mobilenet_v2 instead of mxnet --- apps/ios_rpc/tests/ios_rpc_mobilenet.py | 37 +++++++++++++++++-------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/apps/ios_rpc/tests/ios_rpc_mobilenet.py b/apps/ios_rpc/tests/ios_rpc_mobilenet.py index 1872cf678779..85a430317765 100644 --- a/apps/ios_rpc/tests/ios_rpc_mobilenet.py +++ b/apps/ios_rpc/tests/ios_rpc_mobilenet.py @@ -23,7 +23,6 @@ import coremltools import numpy as np import tvm -from mxnet import gluon from PIL import Image from tvm import relay, rpc from tvm.contrib import coreml_runtime, graph_executor, utils, xcode @@ -51,6 +50,8 @@ def compile_metal(src, target): def prepare_input(): + from torchvision import transforms + img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true" img_name = "cat.png" synset_url = "".join( @@ -62,22 +63,36 @@ def prepare_input(): ] ) synset_name = "imagenet1000_clsid_to_human.txt" - img_path = download_testdata(img_url, "cat.png", module="data") + img_path = download_testdata(img_url, img_name, module="data") synset_path = download_testdata(synset_url, synset_name, module="data") with open(synset_path) as f: synset = eval(f.read()) - image = Image.open(img_path).resize((224, 224)) + input_image = Image.open(img_path) - image = np.array(image) - np.array([123.0, 117.0, 104.0]) - image /= np.array([58.395, 57.12, 57.375]) - image = image.transpose((2, 0, 1)) - image = image[np.newaxis, :] - return image.astype("float32"), synset + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + input_tensor = preprocess(input_image) + input_batch = input_tensor.unsqueeze(0) + return input_batch.detach().cpu().numpy(), synset def get_model(model_name, data_shape): - gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True) - mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) + import torch + import torchvision + + torch_model = getattr(torchvision.models, model_name)(weights="IMAGENET1K_V1").eval() + input_data = torch.randn(data_shape) + scripted_model = torch.jit.trace(torch_model, input_data) + + input_infos = [("data", input_data.shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, input_infos) + # we want a probability so add a softmax operator func = mod["main"] func = relay.Function( @@ -90,7 +105,7 @@ def get_model(model_name, data_shape): def test_mobilenet(host, port, key, mode): temp = utils.tempdir() image, synset = prepare_input() - model, params = get_model("mobilenetv2_1.0", image.shape) + model, params = get_model("mobilenet_v2", image.shape) def run(mod, target): with relay.build_config(opt_level=3): From 26fec76b93806587c5c9bf614b5d3aa218b6e53f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 5 Sep 2024 11:45:04 -0500 Subject: [PATCH 112/202] [CI][Hexagon] Forward gtest tests into pytest as separate tests (#17334) * [CI][Hexagon] Forward gtest tests into pytest as separate tests Prior to this commit, all Hexagon test cases in `tests/cpp-runtime/hexagon` were executed as part of a single unit test in pytest. This can take a significant portion of the total timeout in CI (~50 minutes out of a 2-hour timeout). While the hexagon tests are split out onto 8 separate runners, having a single large test can cause timeouts on whichever runner happens to receive it. This commit exposes each unit test from `tests/cpp-runtime/hexagon` into a separate unit test in pytest, to avoid these timeouts. * lint fix --- .../test_hexagon/test_run_unit_tests.py | 132 +++++++++++++++++- 1 file changed, 130 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_run_unit_tests.py b/tests/python/contrib/test_hexagon/test_run_unit_tests.py index cd4e5c9b0d66..1651783e3456 100644 --- a/tests/python/contrib/test_hexagon/test_run_unit_tests.py +++ b/tests/python/contrib/test_hexagon/test_run_unit_tests.py @@ -15,18 +15,139 @@ # specific language governing permissions and limitations # under the License. -""" capture gtest output and return over FFI """ +# pylint: disable=redefined-outer-name + +"""capture gtest output and return over FFI""" import tvm +import tvm.testing from tvm.contrib.hexagon.session import Session +unit_test_name = tvm.testing.parameter( + "HexagonUserDMATest.wait", + "HexagonUserDMATest.poll", + "HexagonUserDMATest.bad_copy", + "HexagonUserDMATest.sync_dma", + "HexagonUserDMATest.async_dma_wait", + "HexagonUserDMATest.async_dma_poll", + "HexagonUserDMATest.pipeline", + "HexagonUserDMATest.pipeline_write_queue", + "HexagonUserDMATest.overflow_ring_buffer", + "HexagonUserDMATest.sync_dma_bypass", + "HexagonUserDMATest.sync_dma_bypass_vtcm_to_vtcm", + "HexagonUserDMATest.sync_dma_bypass_", + "HexagonBuffer.default_scope", + "HexagonBuffer.ddr_scope", + "HexagonBuffer.vtcm_scope", + "HexagonBuffer.invalid_scope", + "HexagonBuffer.micro_copies_corresponding_regions", + "HexagonBuffer.micro_copies_src_bigger", + "HexagonBuffer.micro_copies_dest_bigger", + "HexagonBuffer.micro_copies_src_overlaps_dest_region", + "HexagonBuffer.micro_copies_dest_overlaps_src_region", + "HexagonBuffer.micro_copies_discontiguous_regions", + "HexagonBuffer.micro_copies_invalid_size", + "HexagonBuffer.macro_copies_adjacent_corresponding_regions_merged", + "HexagonBuffer.macro_copies_discontiguous_regions_not_merged", + "HexagonBuffer.macro_copies_overlapping_regions_merged", + "HexagonBuffer.copy_from", + "HexagonBuffer.copy_from_invalid_size", + "HexagonBuffer.copy_from_smaller_size", + "HexagonBuffer.nd", + "HexagonBuffer.nd_copy_from", + "HexagonBuffer.1d_copy_from_1d", + "HexagonBuffer.2d_copy_from_1d", + "HexagonBuffer.1d_copy_from_2d", + "HexagonBuffer.nd_copy_from_nd_invalid_size", + "HexagonBuffer.nd_copy_from_nd_smaller_size", + "HexagonBuffer.md_copy_from_nd", + "HexagonBuffer.copy_to", + "HexagonBuffer.nd_copy_to", + "RingBufferTest.zero_size_ring_buffer", + "RingBufferTest.in_flight", + "RingBufferTest.next", + "RingBufferTest.full", + "RingBufferTest.wrap", + "RingBufferTest.wrap_corner", + "RingBufferTest.half_in_flight", + "RingBufferTest.half_in_flight_blocked", + "QueuedRingBufferTest.invalid_queue", + "QueuedRingBufferTest.two_queues", + "QueuedRingBufferTest.group_end_before_group_start", + "QueuedRingBufferTest.group_restart", + "QueuedRingBufferTest.zero_size_group", + "QueuedRingBufferTest.in_flight_before_group_end", + "QueuedRingBufferTest.group_of_one", + "QueuedRingBufferTest.group_of_two", + "QueuedRingBufferTest.group_of_three", + "QueuedRingBufferTest.two_groups_of_two", + "QueuedRingBufferTest.two_queues_two_groups_of_two", + "HexagonVtcmPoolTest.basic", + "HexagonVtcmPoolTest.small_allocations", + "HexagonVtcmPoolTest.no_free_vtcm", + "HexagonVtcmPoolTest.not_enough_free_vtcm", + "HexagonVtcmPoolTest.free_with_wrong_size", + "HexagonVtcmPoolTest.free_alloc_combinations", + "HexagonVtcmPoolTest.find_allocation", + "HexagonVtcmPoolTest.find_smallest_allocation_combinations", + "HexagonVtcmPoolTest.vtcm_alignment", + "HexagonThreadManagerTest.ctor_edge_cases", + "HexagonThreadManagerTest.init", + "HexagonThreadManagerTest.dispatch", + "HexagonThreadManagerTest.dispatch_wait", + "HexagonThreadManagerTest.wait_signal", + "HexagonThreadManagerTest.re_signal", + "HexagonThreadManagerTest.re_wait", + "HexagonThreadManagerTest.wait_signal_x2", + "HexagonThreadManagerTest.signal_wait", + "HexagonThreadManagerTest.sync_from_to", + "HexagonThreadManagerTest.sync_from_to_self", + "HexagonThreadManagerTest.sync_from_to_x2", + "HexagonThreadManagerTest.sync_from_to_all", + "HexagonThreadManagerTest.pipe_fill", + "HexagonThreadManagerTest.pipe_overflow", + "HexagonThreadManagerTest.producer_consumer", + "HexagonThreadManagerTest.producer_consumer_signal_wait", + "HexagonThreadManagerTest.thread_order", + "HexagonThreadManagerTest.thread_order_signal_wait", + "HexagonThreadManagerTest.dispatch_writes", + "HexagonThreadManagerTest.threads_for_resource_types", + "HexagonUtilsActivationsBlockizeTest.prepare_nhwc", + "HexagonUtilsActivationsBlockizeTest.blockize_hwc_16b", + "HexagonUtilsActivationsBlockizeTest.deblockize_hwc_16b", + "HexagonUtilsWeightsChunkifyTest.calculate_num_weight_chunks", + "HexagonUtilsWeightsChunkifyTest.prepare_hwio", + "HexagonUtilsWeightsChunkifyTest.chunkify_hwio_16b", + "HexagonUtilsQuantActivationsBlockizeTest.prepare_nhwc", + "HexagonUtilsQuantActivationsBlockizeTest.blockize_hwc_8b", + "HexagonUtilsQuantActivationsBlockizeTest.deblockize_hwc_8b", + "HexagonUtilsQuantWeightsChunkifyTest.calculate_num_weight_chunks", + "HexagonUtilsQuantWeightsChunkifyTest.prepare_hwio", + "HexagonUtilsQuantWeightsChunkifyTest.chunkify_hwio_8b", + "HexagonDeviceAPITest.global", + "HexagonDeviceAPITest.alloc_free_cpu", + "HexagonDeviceAPITest.alloc_free_hex", + "HexagonDeviceAPITest.alloc_errors", + "HexagonDeviceAPITest.free_errors", + "HexagonDeviceAPITest.allocnd_free_cpu", + "HexagonDeviceAPITest.allocnd_free_hex", + "HexagonDeviceAPITest.allocnd_free_hex_vtcm", + "HexagonDeviceAPITest.allocnd_erros", + "HexagonDeviceAPITest.alloc_scalar", + "HexagonDeviceAPITest.DISABLED_alloc_free_diff_dev", + "HexagonDeviceAPITest.runtime_buffer_manager", + "HexagonDeviceAPITest.thread_manager", + "HexagonDeviceAPITest.user_dma", + "HexagonDeviceAPITest.vtcm_pool", +) + # use pytest -sv to observe gtest output # use --gtest_args to pass arguments to gtest # for example to run all "foo" tests twice and observe gtest output run # pytest -sv --gtests_args="--gtest_filter=*foo* --gtest_repeat=2" @tvm.testing.requires_hexagon -def test_run_unit_tests(hexagon_session: Session, gtest_args): +def test_run_unit_tests(hexagon_session: Session, gtest_args, unit_test_name): """Try running gtest unit tests and capture output and error code""" try: func = hexagon_session._rpc.get_function("hexagon.run_unit_tests") @@ -40,6 +161,13 @@ def test_run_unit_tests(hexagon_session: Session, gtest_args): ) raise + # Prepend the unit test name, so command-line arguments still take + # precedence, but CI runs each gtest as a separate pytest case. + if gtest_args: + gtest_args = f"--gtest_filter={unit_test_name} {gtest_args}" + else: + gtest_args = f"--gtest_filter={unit_test_name}" + gtest_error_code_and_output = func(gtest_args) gtest_error_code = int(gtest_error_code_and_output.splitlines()[0]) gtest_output = gtest_error_code_and_output.split("\n", 1)[-1] From dbe95c43b2afde26eab428181d47cfc939d153c1 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Fri, 6 Sep 2024 20:45:36 +0800 Subject: [PATCH 113/202] [MSC][BugFix] Bugfix for strided_slice op (#17315) support strided_slice --- src/contrib/msc/core/codegen/base_codegen.h | 6 +- src/contrib/msc/core/ir/graph_builder.cc | 13 +++- .../msc/core/transform/bind_named_params.cc | 2 +- src/contrib/msc/core/utils.cc | 67 ++++++++++++++++++- src/contrib/msc/core/utils.h | 54 +++++++++++++-- .../contrib/test_msc/test_graph_build.py | 3 - .../contrib/test_msc/test_translate_relax.py | 4 -- .../test_msc/test_translate_tensorflow.py | 4 -- .../contrib/test_msc/test_translate_torch.py | 3 - 9 files changed, 128 insertions(+), 28 deletions(-) diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h index 19d8b524b9e2..acaac896a153 100644 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ b/src/contrib/msc/core/codegen/base_codegen.h @@ -179,17 +179,17 @@ class BaseCodeGen { return 1; } if (node->scope.size() == scopes_.top().size()) { - ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top())) + ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top())) << "Scope mismatch, node " << node->scope << " compare to current " << scopes_.top(); return 0; } else if (node->scope.size() == scopes_.top().size() + 1) { - ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size())) + ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size())) << "Scope increase mismatch, node " << node->scope << " compare to current " << scopes_.top(); scopes_.push(node->scope); return 1; } else if (node->scope.size() == scopes_.top().size() - 1) { - ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size())) + ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size())) << "Scope decrease mismatch, node " << node->scope << " compare to current " << scopes_.top(); scopes_.pop(); diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index d35a462579d9..a968df4204a2 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -23,6 +23,7 @@ #include "graph_builder.h" +#include #include namespace tvm { @@ -71,6 +72,13 @@ void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode* op) { for (const auto& arg : op->args) { if (const auto* s_node = arg.as()) { values_.push_back(StringUtils::ToString(s_node->value)); + } else if (const auto* s_node = arg.as()) { + bool all_values = + std::all_of(s_node->fields.begin(), s_node->fields.end(), + [](const relax::Expr& e) { return e->IsInstance(); }); + if (all_values) { + values_.push_back(StringUtils::ToString(s_node->fields)); + } } } } @@ -337,6 +345,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype << " should has special type, get " << input_types; attrs.Set(input_types[i], StringUtils::ToString(s_node->value)); + } else if (input_types[i] != "input" && arg->IsInstance()) { + attrs.Set(input_types[i], StringUtils::ToString(arg)); } } for (size_t i = call->args.size(); i < input_types.size(); i++) { @@ -371,7 +381,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional Array arg_names; if (expr_tensor_map_.count(arg)) { arg_names = expr_tensor_map_[arg]; - } else if (const auto* tuple_node = arg.as()) { + } else if (input_types[i] == "input" && arg->IsInstance()) { + const auto* tuple_node = arg.as(); for (const auto& f : tuple_node->fields) { ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; for (const auto& in_name : expr_tensor_map_[f]) { diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index 5ba1ca30eb1c..6256fae05f83 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -84,7 +84,7 @@ std::tuple, Map> NormalizeNamedBindings( if (auto opt = obj.as()) { return opt.value(); } else if (auto opt = obj.as()) { - const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, key->name_hint()); + const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint()); return Constant(opt.value(), StructInfo(), span); } else { LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey() diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index 5fcbe924ae1c..c6e74d42843d 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -280,6 +280,8 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { } } else if (const auto* n = obj.as()) { obj_string = ToString(n->value); + } else if (const auto* n = obj.as()) { + obj_string = ToString(n->fields); } else { std::ostringstream obj_des; obj_des << obj; @@ -288,7 +290,7 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { return obj_string; } -bool StringUtils::CompareArrays(const Array& left, const Array& right, int size) { +bool ArrayUtils::CompareArrays(const Array& left, const Array& right, int size) { if (left.size() == right.size() && left.size() == 0) { return true; } @@ -311,6 +313,37 @@ bool StringUtils::CompareArrays(const Array& left, const Array& return true; } +PrimExpr ArrayUtils::Accumulate(const Array& array, int pos) { + size_t t_pos = pos < 0 ? array.size() + pos + 1 : pos; + PrimExpr accumulate = Integer(1); + for (size_t i = 0; i < t_pos; i++) { + accumulate = accumulate * array[i]; + } + return accumulate; +} + +bool ArrayUtils::Broadcastable(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); i++) { + const auto& lp = lhs[i]; + const auto& rp = rhs[i]; + if (lp->IsInstance() && rp->IsInstance()) { + continue; + } + if (lp->IsInstance() && rp->IsInstance() && + Downcast(lp)->value == Downcast(rp)->value) { + continue; + } + if (lp->IsInstance() && Downcast(lp)->value == 1) { + continue; + } + return false; + } + return true; +} + const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& value) { if (value.size() == 0) { return span; @@ -353,6 +386,10 @@ const Map SpanUtils::GetAttrs(const Span& span) { return attrs; } +const Span SpanUtils::CreateWithAttr(const String& key, const String& value) { + return SetAttr(Span(), key, value); +} + const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs_num, bool as_relax) { Array input_types; @@ -370,6 +407,14 @@ const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs } else if (optype == "full" && as_relax) { input_types.push_back("shape"); input_types.push_back("input"); + } else if (optype == "strided_slice") { + input_types.push_back("input"); + if (inputs_num > 1) { + input_types.push_back("axes"); + input_types.push_back("begin"); + input_types.push_back("end"); + input_types.push_back("strides"); + } } else if (optype == "triu") { input_types.push_back("input"); input_types.push_back("k"); @@ -454,13 +499,31 @@ const Array ExprUtils::GetInputTypes(const RelayCall& call) { return GetInputTypes(optype, call->args.size(), false); } +const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { + const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName); + if (suffix.size() > 0) { + return name + "_" + suffix; + } + return name; +} + +const Array ExprUtils::GetShape(const Expr& expr) { + const auto& shape_opt = Downcast(relax::GetStructInfo(expr))->GetShape(); + ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr; + return shape_opt.value(); +} + +const DataType ExprUtils::GetDataType(const Expr& expr) { + return Downcast(relax::GetStructInfo(expr))->dtype; +} + TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); TVM_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr") .set_body_typed([](const String& key, const String& value) -> Span { - return SpanUtils::SetAttr(Span(), key, value); + return SpanUtils::CreateWithAttr(key, value); }); TVM_REGISTER_GLOBAL("msc.core.SpanSetAttr") diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 6c39a8d0a16a..d7758cc23d8b 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -175,13 +176,6 @@ class StringUtils { * \return The String. */ TVM_DLL static const String ToString(const runtime::ObjectRef& obj); - - /*! - * \brief Compare String arrays. - * \return Whether two array are same. - */ - TVM_DLL static bool CompareArrays(const Array& left, const Array& right, - int size = -1); }; /*! @@ -238,6 +232,10 @@ class ArrayUtils { return new_array; } + /*! + * \brief Product elements in the arrays. + * \return The producted array + */ template TVM_DLL static const Array> Product(const Array>& arrays) { Array> p_arrays; @@ -260,6 +258,24 @@ class ArrayUtils { } return p_arrays; } + + /*! + * \brief Compare String arrays. + * \return Whether two array are same. + */ + TVM_DLL static bool CompareArrays(const Array& left, const Array& right, + int size = -1); + /*! + * \brief Accumulate array. + * \return The accumulate result + */ + TVM_DLL static PrimExpr Accumulate(const Array& array, int pos = -1); + + /*! + * \brief Check if lhs array is broadcastable to rhs. + * \return broadcastable + */ + TVM_DLL static bool Broadcastable(const Array& lhs, const Array& rhs); }; /*! @@ -284,6 +300,12 @@ class SpanUtils { * \return The Attrs Map. */ TVM_DLL static const Map GetAttrs(const Span& span); + + /*! + * \brief Create a span with value. + * \return The created Span. + */ + TVM_DLL static const Span CreateWithAttr(const String& key, const String& value); }; /*! @@ -365,6 +387,24 @@ class ExprUtils { TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i = 0) { return GetScalar(constant->data, i); } + + /*! + * \brief Get name in span. + * \return The name. + */ + TVM_DLL static const String GetSpanName(const Expr& expr, const String& suffix = ""); + + /*! + * \brief Get shape of expr. + * \return The shape. + */ + TVM_DLL static const Array GetShape(const Expr& expr); + + /*! + * \brief Get dtype of expr. + * \return The shape. + */ + TVM_DLL static const DataType GetDataType(const Expr& expr); }; } // namespace msc diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index 069ffff53bd7..d02767208206 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -17,8 +17,6 @@ """ Test graph builder && graph. """ -import pytest - import torch from torch import fx from torch.nn import Module @@ -1101,7 +1099,6 @@ def forward(self, data): verify_model(GetAttr1(), input_info, expected) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_getitem(): """test graph builder for getitem""" diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index e8b7149a68a2..66aa90a625ea 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -17,8 +17,6 @@ """ Test translate from relax. """ -import pytest - import torch from torch import fx from torch.nn import Module @@ -57,7 +55,6 @@ def _run_relax(relax_mod): relax_exec = tvm.relax.build(relax_mod, target) vm_runner = tvm.relax.VirtualMachine(relax_exec, dev) res = vm_runner["main"](*args) - return _tvm_runtime_to_np(res) rt_mod = tvm_codegen.to_relax( @@ -629,7 +626,6 @@ def forward(self, data): _verify_model(GetAttr1(), input_info) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_getitem(): """test relax translator for getitem""" diff --git a/tests/python/contrib/test_msc/test_translate_tensorflow.py b/tests/python/contrib/test_msc/test_translate_tensorflow.py index 61f8ce1a973c..cb4ea3c02e4b 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorflow.py +++ b/tests/python/contrib/test_msc/test_translate_tensorflow.py @@ -18,8 +18,6 @@ """ Test translate from tensorflow. """ -import pytest - from packaging import version as package_version import numpy as np @@ -504,7 +502,6 @@ def _test_stridedslice( verify_model(graph_def, golden, **io_info) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_stridedslice(): """test tensorflow translator for stridedslice""" @@ -1065,7 +1062,6 @@ def _test_slice_operation_input(input_value, begin_value, size_value): verify_model(graph_def, golden, **io_info) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_slice(): """test tensorflow translator for slice""" diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py index 60dcbb293a51..f3e01493d96a 100644 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ b/tests/python/contrib/test_msc/test_translate_torch.py @@ -17,8 +17,6 @@ """ Test translate from torch. """ -import pytest - import numpy as np import torch @@ -589,7 +587,6 @@ def forward(self, data): verify_model(GetAttr1(), input_info) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_getitem(): """test torch translator for getitem""" From f33cc8f5597edf6687fb54535ced5d292a4dd778 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 6 Sep 2024 22:14:32 +0900 Subject: [PATCH 114/202] [Relax][PyTorch] Add support for `torch.ops.aten.sym_size.int` (#17342) * add a test for `torch.ops.aten.sym_size.int` * add support for `torch.ops.aten.sym_size.int` * cleanup --- .../tvm/relax/frontend/torch/fx_translator.py | 7 ++++++ tests/python/relax/test_frontend_from_fx.py | 25 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 6e60c3bb6fc4..aed38d7c49ea 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1464,6 +1464,12 @@ def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var: ########## Others ########## + def _sym_size_int(self, node: fx.node.Node) -> relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + idx = node.args[1] + return self.block_builder.emit(relax.const(shape[idx].value, "int32")) + def _size(self, node: fx.node.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -1680,6 +1686,7 @@ def create_convert_map(self): "hardsigmoid": self._hardsigmoid, "hardswish": self._hardswish, "interpolate": self._interpolate, + "sym_size.int": self._sym_size_int, "size": self._size, "getattr": self._getattr, "getitem": self._getitem, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 35a9bc71bf98..78fc7abdf748 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3929,5 +3929,30 @@ def main( ) +def test_sym_size_int(): + class SymSizeInt1(Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.ops.aten.sym_size.int(x, self.dim) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 4), dtype="float32"), + ) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + lv: R.Tensor((), dtype="int32") = R.const(3, "int32") + gv: R.Tensor((), dtype="int32") = lv + R.output(gv) + return gv + + verify_model(SymSizeInt1(dim=1), [([1, 3, 4], "float32")], {}, Expected1) + verify_model(SymSizeInt1(dim=-2), [([1, 3, 4], "float32")], {}, Expected1) + + if __name__ == "__main__": tvm.testing.main() From f432ebd5f553c166c8dccc1d0900c7ef8628ad5c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 6 Sep 2024 08:17:11 -0500 Subject: [PATCH 115/202] [Relax] Update GlobalVar name in AttachGlobalSymbol (#17202) * [IR] Implement cross-IR call-map collection Prior to this commit, the `relax.transform.DeadCodeElimination` only considered calls from Relax to TIR when identifying unused functions. This would erroneously remove TIR functions that are called indirectly. This commit adds a new utility `tvm.ir.analysis.collect_call_map`, which can collect the call map of an `IRModule` across both Relax and TIR, using it in Relax's `DeadCodeElimination` transform. * [Relax] Update GlobalVar name in AttachGlobalSymbol Prior to this commit, the `relax.transform.AttachGlobalSymbol` pass could produce a PrimFunc whose `"global_symbol"` attribute does not match the name of the `GlobalVar`. As a result, the PackedFunc that is provided by the compiled module (defined by the `"global_symbol"`) does not match the PackedFunc that is required by the Relax VM (defined by the `GlobalVar` name). This commit updates `AttachGlobalSymbol` to replace the `GlobalVar` of any function whose `"global_symbol"` is updated. Closes https://github.com/apache/tvm/issues/17176 * lint fixes * lint fixes --- include/tvm/ir/analysis.h | 63 ++++++++++++ include/tvm/ir/replace_global_var.h | 57 +++++++++++ python/tvm/ir/__init__.py | 3 + python/tvm/ir/_ffi_analysis_api.py | 22 +++++ python/tvm/ir/analysis.py | 44 +++++++++ src/ir/analysis.cc | 49 ++++++++++ src/ir/replace_global_var.cc | 63 ++++++++++++ src/relax/analysis/collect_call_map.cc | 56 +++++++++++ src/relax/transform/attach_global_symbol.cc | 48 ++++++--- src/relax/transform/dead_code_elimination.cc | 94 +++++------------- src/relax/transform/replace_global_var.cc | 66 +++++++++++++ src/tir/analysis/collect_call_map.cc | 57 +++++++++++ src/tir/transforms/replace_global_var.cc | 68 +++++++++++++ .../ir/analysis/test_collect_call_map.py | 97 +++++++++++++++++++ .../test_transform_attach_global_symbol.py | 6 +- .../test_transform_dead_code_elimination.py | 60 +++++++++++- 16 files changed, 762 insertions(+), 91 deletions(-) create mode 100644 include/tvm/ir/analysis.h create mode 100644 include/tvm/ir/replace_global_var.h create mode 100644 python/tvm/ir/_ffi_analysis_api.py create mode 100644 python/tvm/ir/analysis.py create mode 100644 src/ir/analysis.cc create mode 100644 src/ir/replace_global_var.cc create mode 100644 src/relax/analysis/collect_call_map.cc create mode 100644 src/relax/transform/replace_global_var.cc create mode 100644 src/tir/analysis/collect_call_map.cc create mode 100644 src/tir/transforms/replace_global_var.cc create mode 100644 tests/python/ir/analysis/test_collect_call_map.py diff --git a/include/tvm/ir/analysis.h b/include/tvm/ir/analysis.h new file mode 100644 index 000000000000..afe18792dee0 --- /dev/null +++ b/include/tvm/ir/analysis.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/analysis.h + * + * Analysis routines that must function across multiple IR types for + * correctness. For example, identifying unused functions, when both TIR + * + */ +#ifndef TVM_IR_ANALYSIS_H_ +#define TVM_IR_ANALYSIS_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ir { + +class CalleeCollector { + public: + /* \brief Functor to be registered for IR types + * + * Should be implemented for each `BaseFunc` subclass. + * Implementation should call `CalleeCollector::Mark` for each + * `GlobalVar` in the function. + */ + using FType = NodeFunctor; + TVM_DLL static FType& vtable() { + static FType inst; + return inst; + } + + virtual ~CalleeCollector() {} + + /* \brief Collect the GlobalVar in a function */ + virtual void Mark(GlobalVar gvar) = 0; +}; + +Map> CollectCallMap(const IRModule& mod); + +} // namespace ir +} // namespace tvm + +#endif // TVM_IR_ANALYSIS_H_ diff --git a/include/tvm/ir/replace_global_var.h b/include/tvm/ir/replace_global_var.h new file mode 100644 index 000000000000..c15dd5f4e5ad --- /dev/null +++ b/include/tvm/ir/replace_global_var.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/replace_global_var.h + * + * \brief A utility to replace GlobalVar instances across all TVM IR + * types in an IRMdoule. + */ +#ifndef TVM_IR_REPLACE_GLOBAL_VAR_H_ +#define TVM_IR_REPLACE_GLOBAL_VAR_H_ + +#include + +namespace tvm { +namespace transform { + +/*! + * \brief Replace GlobalVar instances across any IR type. + * + * \param mod The module to update + * + * \param replacements The map, where each entry maps from an old + * `GlobalVar` to the new `GlobalVar` that should replace it. + * + * \return The updated IRModule + */ +TVM_DLL IRModule ReplaceGlobalVar(IRModule mod, Map replacements); + +struct GlobalVarReplacer { + using FType = NodeFunctor)>; + TVM_DLL static FType& vtable() { + static FType inst; + return inst; + } +}; + +} // namespace transform +} // namespace tvm + +#endif // TVM_IR_REPLACE_GLOBAL_VAR_H_ diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 939a5f638381..fdac74a0b4ec 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=unused-import """Common data structures across all IR variants.""" + from . import diagnostics, instrument, transform from .adt import Constructor, TypeData from .affine_type import TensorAffineType, TupleAffineType @@ -61,3 +62,5 @@ TypeVar, ) from .type_relation import TypeCall, TypeRelation + +from . import analysis diff --git a/python/tvm/ir/_ffi_analysis_api.py b/python/tvm/ir/_ffi_analysis_api.py new file mode 100644 index 000000000000..0013ec3b5026 --- /dev/null +++ b/python/tvm/ir/_ffi_analysis_api.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.ir.analysis""" + +import tvm._ffi + + +tvm._ffi._init_api("ir.analysis", __name__) diff --git a/python/tvm/ir/analysis.py b/python/tvm/ir/analysis.py new file mode 100644 index 000000000000..11fa819e2275 --- /dev/null +++ b/python/tvm/ir/analysis.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=unused-import + +"""Common analysis across all IR variants.""" + +from typing import Dict, List + +import tvm +from . import _ffi_analysis_api as _ffi + + +def collect_call_map( + module: "tvm.ir.IRModule", +) -> Dict["tvm.ir.GlobalVar", List["tvm.ir.GlobalVar"]]: + """Collect the call map of a module + + Parameters + ---------- + module: tvm.ir.IRModule + The module to inspect + + Returns + ------- + call_map: Dict[tvm.ir.GlobalVar, List[tvm.ir.GlobalVar]] + A map from functions to the subroutines they call. + + """ + return _ffi.CollectCallMap(module) diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc new file mode 100644 index 000000000000..9de36b0a28af --- /dev/null +++ b/src/ir/analysis.cc @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/analysis.cc + * \brief Analysis functions that must span multiple IR types + */ +#include + +#include "../support/ordered_set.h" + +namespace tvm { +namespace ir { + +Map> CollectCallMap(const IRModule& mod) { + struct CalleeCollectorImpl : CalleeCollector { + void Mark(GlobalVar gvar) override { gvars.push_back(gvar); } + support::OrderedSet gvars; + }; + + Map> call_map; + for (const auto& [gvar, base_func] : mod->functions) { + CalleeCollectorImpl collector; + CalleeCollector::vtable()(base_func, &collector); + call_map.Set(gvar, Array{collector.gvars.begin(), collector.gvars.end()}); + } + return call_map; +} + +TVM_REGISTER_GLOBAL("ir.analysis.CollectCallMap").set_body_typed(CollectCallMap); + +} // namespace ir +} // namespace tvm diff --git a/src/ir/replace_global_var.cc b/src/ir/replace_global_var.cc new file mode 100644 index 000000000000..08d66d0e7cf2 --- /dev/null +++ b/src/ir/replace_global_var.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/replace_global_var.cc + * \brief IRModule transform to replace GlobalVar instances across any IR type. + */ + +#include + +#include + +namespace tvm { +namespace transform { + +IRModule ReplaceGlobalVar(IRModule mod, Map replacements) { + std::vector to_remove; + IRModule updates; + + const auto& vtable = GlobalVarReplacer::vtable(); + + for (const auto& [old_gvar, old_func] : mod->functions) { + auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar); + auto new_func = vtable(old_func, replacements); + + if (!new_gvar.same_as(old_gvar)) { + to_remove.push_back(old_gvar); + } + if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) { + updates->Add(new_gvar, new_func); + } + } + + if (to_remove.size() || updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + for (const auto& old_gvar : to_remove) { + write_ptr->Remove(old_gvar); + } + write_ptr->Update(updates); + } + return mod; +} + +TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVar").set_body_typed(ReplaceGlobalVar); + +} // namespace transform +} // namespace tvm diff --git a/src/relax/analysis/collect_call_map.cc b/src/relax/analysis/collect_call_map.cc new file mode 100644 index 000000000000..3e0170d3444d --- /dev/null +++ b/src/relax/analysis/collect_call_map.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/relax/analysis/collect_call_map.cc + * + * \brief Collect cross-IR call graph + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +namespace { +using ir::CalleeCollector; + +struct Visitor : ExprVisitor { + explicit Visitor(CalleeCollector* collector) : collector(collector) {} + CalleeCollector* collector; + void VisitExpr_(const GlobalVarNode* node) override { collector->Mark(GetRef(node)); } +}; + +} // namespace + +TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable) + .set_dispatch([](const ObjectRef& func, CalleeCollector* collector) { + Visitor visitor{collector}; + visitor(Downcast(func)); + }); + +TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable) + .set_dispatch([](const ObjectRef& func, CalleeCollector* collector) {}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 9b2a561c7fec..a517d5a035e2 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -22,6 +22,8 @@ */ #include +#include +#include #include #include @@ -32,26 +34,46 @@ namespace transform { Pass AttachGlobalSymbol() { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { - mod.CopyOnWrite(); - String c_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix).value_or(""); - std::vector > updates; + IRModule updates; + Map gvar_updates; + + for (const auto& [gvar, func] : mod->functions) { + Optional old_name = func->GetAttr(tvm::attr::kGlobalSymbol); - for (auto& p : mod->functions) { - BaseFunc func = p.second; // TODO(tvm-team): re-enable once fix relax integration part - // if (func->GetAttr(tvm::attr::kGlobalSymbol)) continue; + // if (old_name) continue; + + Optional new_name; + BaseFunc new_func; + if (auto* prim_func = func.as()) { - updates.emplace_back(p.first, - WithAttr(GetRef(prim_func), tvm::attr::kGlobalSymbol, - c_prefix + p.first->name_hint)); + new_name = c_prefix + gvar->name_hint; + new_func = WithAttr(GetRef(prim_func), tvm::attr::kGlobalSymbol, new_name); } else if (auto* relax_func = func.as()) { - updates.emplace_back(p.first, WithAttr(GetRef(relax_func), - tvm::attr::kGlobalSymbol, p.first->name_hint)); + new_name = gvar->name_hint; + new_func = WithAttr(GetRef(relax_func), tvm::attr::kGlobalSymbol, new_name); + } + + if (new_name.defined() && (!old_name.defined() || old_name.value() != new_name.value())) { + updates->Add(gvar, new_func); + if (new_name.value() != gvar->name_hint) { + GlobalVar new_gvar(new_name.value()); + if (auto sinfo = gvar->struct_info_.as()) { + UpdateStructInfo(new_gvar, sinfo.value()); + } + + gvar_updates.Set(gvar, new_gvar); + } } } - for (const auto& pair : updates) { - mod->Add(pair.first, pair.second, true); + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + + if (gvar_updates.size()) { + mod = tvm::transform::ReplaceGlobalVar(mod, gvar_updates); + } } return mod; }; diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 9591b45595f9..4305554342ad 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -32,6 +32,7 @@ * Any binding blocks that are left empty will be removed by the normalizer. */ +#include #include #include #include @@ -42,89 +43,40 @@ namespace tvm { namespace relax { -/** - * \brief Detects all the functions that can be possibly called by entry function. - */ -class CallTracer : public ExprVisitor { - public: - explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{} {} - - void VisitExpr_(const GlobalVarNode* op) final { - auto gvar = GetRef(op); - called_funcs_.insert(gvar); - if (auto func = mod_->functions.Get(gvar)) { - if (const auto* function_node = func.as()) { - VisitExpr(GetRef(function_node)); - } - // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. - } else { - // The GlobalVar is not contained in the IRModule. While the - // input IRModule is ill-formed, this specific case is allowed - // for use with `relax.transform.ApplyPassToFunction`. If this - // occurs, DCE should not remove any internal functions from the - // IRModule, as their removal is only valid if we have a - // complete call graph. - all_callees_found_ = false; - } - } +IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set& entry_funcs) { + auto call_map = ir::CollectCallMap(mod); + + std::unordered_set reachable = entry_funcs; + std::vector to_visit(entry_funcs.begin(), entry_funcs.end()); + bool all_callees_in_module = true; - void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); } + while (to_visit.size()) { + GlobalVar visiting = to_visit.back(); + to_visit.pop_back(); - void VisitExpr_(const FunctionNode* func_node) final { - auto func = GetRef(func_node); - if (visiting_.find(func) == visiting_.end()) { - visiting_.insert(func); - for (auto param : func_node->params) { - ExprVisitor::VisitExpr(param); + if (auto it = call_map.find(visiting); it != call_map.end()) { + for (GlobalVar callee : (*it).second) { + if (!reachable.count(callee)) { + reachable.insert(callee); + to_visit.push_back(callee); + } } - ExprVisitor::VisitExpr(func_node->body); + } else { + all_callees_in_module = false; } } - void Trace(std::string entry) { - called_funcs_.insert(mod_->GetGlobalVar(entry)); - auto main_func = mod_->Lookup(entry); - VisitExpr(main_func); - } - - /* \brief Check if a function is unreachable - * - * \param gvar The function to be checked - * - * \return True if the function can be proven to be unreachable, - * either directly or indirectly, from an external caller. - * Otherwise, false. - */ - bool CheckIfProvablyUnreachable(const GlobalVar& gvar) const { - return all_callees_found_ && !called_funcs_.count(gvar); - } - - private: - IRModule mod_; - - /* \brief Whether all callees could be located within the IRModule */ - bool all_callees_found_{true}; - - // Record the names of all encountered functions. - std::unordered_set called_funcs_; - - // Record the expressions that are being visited. - std::unordered_set visiting_; -}; - -IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set& entry_funcs) { - CallTracer tracer(mod); - for (const auto& gvar : entry_funcs) { - tracer.VisitExpr(gvar); + if (!all_callees_in_module) { + return mod; } std::vector to_remove; - for (const auto& kv : mod->functions) { + for (const auto& [gvar, func] : mod->functions) { // The tracer contains all user-provided entry functions, all // externally-callable functions, and anything that is directly or // indirectly accessible from an entry function. - if (tracer.CheckIfProvablyUnreachable(kv.first)) { - to_remove.push_back(kv.first); + if (!reachable.count(gvar)) { + to_remove.push_back(gvar); } } diff --git a/src/relax/transform/replace_global_var.cc b/src/relax/transform/replace_global_var.cc new file mode 100644 index 000000000000..b81b831036ff --- /dev/null +++ b/src/relax/transform/replace_global_var.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/relax/transform/replace_global_var.cc + * + * \brief GlobalVar replacement across IR types + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +namespace { +using tvm::transform::GlobalVarReplacer; + +struct Mutator : ExprMutator { + Map replacements; + explicit Mutator(Map replacements) : replacements(replacements) {} + + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const GlobalVarNode* node) override { + auto gvar = GetRef(node); + return replacements.Get(gvar).value_or(gvar); + } +}; + +} // namespace + +TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) + .set_dispatch([](const ObjectRef& func, + Map replacements) -> BaseFunc { + Mutator mutator(replacements); + return Downcast(mutator(Downcast(func))); + }); + +TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) + .set_dispatch([](const ObjectRef& func, + Map) -> BaseFunc { + return Downcast(func); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/tir/analysis/collect_call_map.cc b/src/tir/analysis/collect_call_map.cc new file mode 100644 index 000000000000..98f7585c6b79 --- /dev/null +++ b/src/tir/analysis/collect_call_map.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/tir/analysis/collect_call_map.cc + * + * \brief Collect cross-IR call graph + */ + +#include +#include +#include + +namespace tvm { +namespace tir { + +namespace { +using ir::CalleeCollector; + +struct Visitor : StmtExprVisitor { + explicit Visitor(CalleeCollector* collector) : collector(collector) {} + CalleeCollector* collector; + void VisitExpr_(const CallNode* node) override { + StmtExprVisitor::VisitExpr_(node); + if (auto opt_gvar = node->op.as()) { + collector->Mark(opt_gvar.value()); + } + } +}; + +} // namespace + +TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable) + .set_dispatch([](const ObjectRef& func, CalleeCollector* collector) { + Visitor visitor{collector}; + visitor(Downcast(func)->body); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/replace_global_var.cc b/src/tir/transforms/replace_global_var.cc new file mode 100644 index 000000000000..8ef8ba9276b0 --- /dev/null +++ b/src/tir/transforms/replace_global_var.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/tir/transforms/replace_global_var.cc + * + * \brief GlobalVar replacement across IR types + */ + +#include +#include +#include + +namespace tvm { +namespace tir { + +namespace { +using tvm::transform::GlobalVarReplacer; + +struct Mutator : StmtExprMutator { + Map replacements; + explicit Mutator(Map replacements) : replacements(replacements) {} + + PrimExpr VisitExpr_(const CallNode* node) override { + auto call = Downcast(StmtExprMutator::VisitExpr_(node)); + if (auto old_gvar = call->op.as()) { + if (auto new_gvar = replacements.Get(old_gvar.value())) { + call.CopyOnWrite()->op = new_gvar.value(); + } + } + return call; + } +}; + +} // namespace + +TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) + .set_dispatch([](const ObjectRef& obj, + Map replacements) -> BaseFunc { + Mutator mutator(replacements); + auto func = Downcast(obj); + auto new_body = mutator(func->body); + + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } + return func; + }); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/ir/analysis/test_collect_call_map.py b/tests/python/ir/analysis/test_collect_call_map.py new file mode 100644 index 000000000000..9068bffc5fe0 --- /dev/null +++ b/tests/python/ir/analysis/test_collect_call_map.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, List + +import tvm +import tvm.testing +from tvm.ir import GlobalVar + +from tvm.script import ir as I, tir as T, relax as R + +from tvm.ir.analysis import collect_call_map + + +def _build_str_map(call_map: Dict[GlobalVar, List[GlobalVar]]) -> Dict[str, List[str]]: + return { + caller.name_hint: [callee.name_hint for callee in callees] + for caller, callees in call_map.items() + } + + +def test_collect_relax_to_relax(): + @I.ir_module + class Module: + @R.function + def main(): + return Module.subroutine() + + @R.function + def subroutine(): + return R.tuple() + + call_map = collect_call_map(Module) + str_map = _build_str_map(call_map) + expected = { + "main": ["subroutine"], + "subroutine": [], + } + assert str_map == expected + + +def test_collect_relax_to_tir(): + @I.ir_module + class Module: + @R.function + def main() -> R.Prim("int32"): + return Module.subroutine(R.prim_value(T.int32(42))) + + @T.prim_func + def subroutine(i: T.int32) -> T.int32: + return i + 1 + + call_map = collect_call_map(Module) + str_map = _build_str_map(call_map) + expected = { + "main": ["subroutine"], + "subroutine": [], + } + assert str_map == expected + + +def test_collect_tir_to_tir(): + @I.ir_module + class Module: + @T.prim_func + def main() -> T.int32: + return Module.subroutine(42) + + @T.prim_func + def subroutine(i: T.int32) -> T.int32: + return i + 1 + + call_map = collect_call_map(Module) + str_map = _build_str_map(call_map) + expected = { + "main": ["subroutine"], + "subroutine": [], + } + assert str_map == expected + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index 680df969474a..39f6d061f721 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -89,7 +89,7 @@ def test_system_lib_prefix(): class Before: I.module_attrs({"system_lib_prefix": "hello_"}) - @T.prim_func + @T.prim_func(private=True) def tir_zeros(x: T.Buffer((2), "float32")) -> None: x[0] = T.float32(0) @@ -103,13 +103,13 @@ class Expected: I.module_attrs({"system_lib_prefix": "hello_"}) @T.prim_func - def tir_zeros(x: T.Buffer((2), "float32")) -> None: + def hello_tir_zeros(x: T.Buffer((2), "float32")) -> None: T.func_attr({"global_symbol": "hello_tir_zeros"}) x[0] = T.float32(0) @R.function def main() -> R.Tensor: - gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,), dtype="float32")) + gv0 = R.call_tir(Expected.hello_tir_zeros, (), R.Tensor((2,), dtype="float32")) return gv0 before = Before diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 6546d09777b0..65970d64550e 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -346,6 +346,42 @@ def main( assert check_if_func_exists(new_mod, "unused_func") +def test_preserve_indirectly_used_prim_func(): + @tvm.script.ir_module + class InputModule: + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.call_tir( + InputModule.tir_add_tensors, + [x, w], + out_sinfo=R.Tensor((16, 16), "float32"), + ) + return gv0 + + @T.prim_func(private=True) + def tir_add_tensors( + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), + ): + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = InputModule.tir_add_float32(x[vi, vj], y[vi, vj]) + + @T.prim_func(private=True) + def tir_add_float32(x: T.float32, y: T.float32) -> T.float32: + return x + y + + mod = InputModule + assert mod + new_mod = DeadCodeElimination()(mod) + + tvm.ir.assert_structural_equal(mod, new_mod) + + def test_multiple_unused_funcs(): @tvm.script.ir_module class InputModule: @@ -399,7 +435,11 @@ def main( ) lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( - lv0, lv1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + lv0, + lv1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", ) lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( lv2, axes=[0, 3, 1, 2] @@ -428,7 +468,11 @@ def main( ) lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( - lv0, lv1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + lv0, + lv1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", ) R.output(lv2) gv3 = R.astype(lv2, dtype="float16") @@ -464,7 +508,11 @@ def main( gv_w, axes=[0, 2, 3, 1] ) lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d( - lv1, lv2, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + lv1, + lv2, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", ) # dead instruction -> usee lv1 also dead. lv4: R.Tensor((2, 3, 28, 28), dtype="float32") = R.permute_dims( @@ -491,7 +539,11 @@ def main( gv_w, axes=[0, 2, 3, 1] ) lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d( - lv1, lv2, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + lv1, + lv2, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", ) R.output(lv3) return lv3 From 491a0f69aabcf812cc552df7666038414ca79a8f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 6 Sep 2024 08:32:31 -0500 Subject: [PATCH 116/202] [Relax] Require correct input/output shapes `R.call_tir` (#17285) Prior to this commit, the Relax well-formed checker validated arguments provided to Relax functions, but did not validate arguments provided to `R.call_tir`. As a result, incorrect arguments from Relax to TIR would not be checked until runtime, if at all. This commit updates the well-formed checker to verify that `R.call_tir` has received the correct arguments, and has the correct output shape specified in the `out_sinfo` parameter. Initial implementation performed the validation as part of `FNormalize`, to maximize coverage of this check. This increased end-to-end compilation time by ~10%, and so the check was requested to be restricted to the well-formed checker. Expensive operator-specific validation is now performed in the new `FValidate` attribute. --- include/tvm/relax/op_attr_types.h | 27 + src/relax/analysis/well_formed.cc | 11 + src/relax/op/op.cc | 291 +++++++++- src/relax/transform/fuse_tir.cc | 3 +- ...istributed_transform_propagate_sharding.py | 8 - .../python/relax/test_analysis_well_formed.py | 514 +++++++++++++++++- tests/python/relax/test_ast_printer.py | 9 +- tests/python/relax/test_dataflow_inplace.py | 10 +- tests/python/relax/test_dataflow_pattern.py | 2 +- tests/python/relax/test_frontend_dynamo.py | 7 +- tests/python/relax/test_frontend_nn_op.py | 18 +- tests/python/relax/test_transform.py | 6 +- .../test_transform_dead_code_elimination.py | 30 +- tests/python/relax/test_transform_fuse_ops.py | 8 +- .../test_transform_fuse_ops_by_pattern.py | 18 +- .../test_transform_lazy_transform_params.py | 20 +- ...test_transform_rewrite_dataflow_reshape.py | 25 +- tests/python/relax/test_tvmscript_parser.py | 15 +- tests/python/relax/test_vm_build.py | 12 +- 19 files changed, 928 insertions(+), 106 deletions(-) diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 291bee597c03..0ddc2baefbef 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -56,6 +56,14 @@ using FCallPacked = String; * expressed in multiple syntactically valid and semantically * equivalent forms, to normalize to a single representation. * + * Note: `FNormalize` is applied for each expression as part of the + * `relax::BlockBuilder`. While operator-specific validation may + * be performed within the `FNormalize` implementation, ensuring + * that errors are caught as early as possible, this should only be + * used when validation is fast to apply. If the validation logic + * may be slow, it should instead be implemented in `FValidate`, + * which is only run as part of the well-formed checker. + * * \param bb The BlockBuilder context. * * \param call The call to be normalized. It is provided by-value, to @@ -63,6 +71,25 @@ using FCallPacked = String; */ using FNormalize = runtime::TypedPackedFunc; +/*! + * \brief The function type of a validation function. + * + * A validation function is used to define constraints that should be + * verified for an operator as part of the well-formed checker. + * + * Note: `FValidate` is only applied as part of the well-formed + * checker. While this minimizes overhead while compiling Relax, + * this delay between generating an ill-formed `relax::Call` and + * identifying the ill-formed call may complicate debugging. If + * the validation logic is very fast to check, and doing so would + * not introduce a signficant overhead, consider validating as part + * of `FNormalize`, which is applied by the block builder for each + * `relax::Call`. + * + * \param call The call to be validated. + */ +using FValidate = runtime::TypedPackedFunc; + /*! \brief The function type of a legalization function. * * A legalization function is used to replace a `relax::Call` with diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 626fadda273d..235059ece2aa 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -352,6 +352,16 @@ class WellFormedChecker : public relax::ExprVisitor, << after_normalize); } } + + if (auto func_validate = op_map_validate_.get(call->op, nullptr); func_validate != nullptr) { + try { + func_validate(GetRef(call)); + } catch (std::exception& err) { + Malformed(Diagnostic::Error(call) << "Operator-specific validation (FValidate) for " + << call->op << " identified error: \n" + << err.what()); + } + } } void VisitExpr_(const IfNode* op) final { @@ -574,6 +584,7 @@ class WellFormedChecker : public relax::ExprVisitor, std::unordered_map symbolic_var_func_map_; tvm::OpAttrMap op_map_normalize_ = Op::GetAttrMap("FNormalize"); + tvm::OpAttrMap op_map_validate_ = Op::GetAttrMap("FValidate"); }; bool WellFormed(Variant obj, bool check_struct_info) { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 0a840248ffe8..3e0f0eba313a 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include #include @@ -242,15 +243,195 @@ TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInpla // call_tir +/* If possible, infer a legal value of `arg_sinfo` + * + * The `R.call_tir` operator and its variants accept an `arg_sinfo` + * parameter, which specifies the shape of the tensor or tensors + * returned by a PrimFunc. This output shape must be compatible with + * the shape defined by the PrimFunc's signature. + * + * For dynamic shapes, it is not always possible to infer the output + * of a TIR PrimFunc from its inputs. For example, a PrimFunc that + * accepts input buffer `T.Buffer([16], "float32")` and output buffer + * `T.Buffer([M, N], "float32")` infers the values of `M` and `N` from + * the shape of the provided output buffer. + * + * If the arguments provided are not compatible with the PrimFunc's + * signature, an error will be raised. If the arguments are + * compatible with the PrimFunc's signature, but are not sufficient to + * determine the output's StructInfo, then `NullOpt` will be returned. + * + * \param func_sinfo The StructInfo of the TIR callee. + * \param arg_sinfo The StructInfo of the argument tuple. + * \param packed_ints_sinfo The StructInfo of the ShapeTuple argument, + * if present. + * \param opt_inplace_indices For `R.call_tir_inplace`, an array of + * indices indicating which outputs are constructed from in-place + * mutation of the inputs. See + * `CallTIRInplaceAttrs::inplace_indices` for more details. + * + * \return The `arg_sinfo`, if it can be inferred from the arguments. + * Otherwise, NullOpt. + */ +static Optional InferCallTIROutputStructInfoFromArguments( + StructInfo func_sinfo, StructInfo arg_sinfo, Optional packed_ints_sinfo, + Optional> opt_inplace_indices) { + auto opt_callee_sinfo = func_sinfo.as(); + CHECK(opt_callee_sinfo) << "TypeError: " + << "The first argument to `R.call_tir` must be a function, " + << "but instead received argument of type " << func_sinfo; + auto callee_sinfo = opt_callee_sinfo.value(); + + CHECK(callee_sinfo->params.defined()) + << "ValueError: " + << "The first argument to `R.call_tir` must be a function " + << "with known argument types. " + << "However, the first argument was of type " << callee_sinfo; + auto callee_params = callee_sinfo->params.value(); + + const TupleStructInfoNode* args = arg_sinfo.as(); + CHECK(args) << "TypeError: " + << "The second argument to `R.call_tir` must be a tuple, " + << "but instead received expression of type " << arg_sinfo; + + // R.call_tir expects the PrimFunc to have three groups of arguments. + // + // 1. Input arguments that are explicitly provided as Relax arguments. + // 2. Output tensor arguments. + // 3. Shape arguments, represented as `T.int64` in the PrimFunc, and + // as an optional ShapeExpr argument in the `relax::Call` node. + // + // In order to determine the return type of `R.call_tir`, we must + // identify the PrimFunc arguments that will be in group (2). + size_t num_input_arguments = args->fields.size(); + size_t num_trailing_int_arguments = 0; + const ShapeStructInfoNode* packed_tuple_sinfo = nullptr; + if (packed_ints_sinfo) { + auto packed_sinfo = packed_ints_sinfo.value(); + packed_tuple_sinfo = packed_sinfo.as(); + CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim()) + << "TypeError: " + << "The third argument to `R.call_tir`, if present, " + << "must be a ShapeTuple with known dimensionality. " + << "However, the argument received was of type " << packed_sinfo; + num_trailing_int_arguments = packed_tuple_sinfo->ndim; + } else { + num_trailing_int_arguments = 0; + } + + CHECK_LE(num_input_arguments + num_trailing_int_arguments, callee_params.size()) + << "ValueError: " + << "R.call_tir attempted to call a function using " << num_input_arguments + << " input arguments and " << num_trailing_int_arguments << " trailing integer arguments. " + << "However, the callee only accepts " << callee_params.size() << " arguments in total."; + + // While Relax can specify a distributed tensor, TIR cannot. The + // current implementation does not support determining the output + // shape for `R.dist.call_tir` calls, as it depends on the lowering + // of DistIR into regular Relax. + std::function contains_dtensor = [&contains_dtensor](StructInfo sinfo) -> bool { + if (sinfo.as()) { + return true; + } else if (auto tuple = sinfo.as()) { + return std::any_of(tuple->fields.begin(), tuple->fields.end(), contains_dtensor); + } else { + return false; + } + }; + if (contains_dtensor(arg_sinfo)) { + return NullOpt; + } + + // At this point, the return types are known. However, the shapes + // in `callee_params` may contain dynamic shape parameters that are + // not present in the caller's scope. The `DeriveCallRetStructInfo` + // utility can infer the value of dynamic parameters in + // `FuncStructInfoNode::ret` based on definitions in + // `FuncStructInfoNode::params`, inferring the correct values in the + // caller's scope. + // + // Since the callee of `R.call_tir` is provided with output + // arguments, where `DeriveCallRetStructInfo` requires a callee that + // produces its own outputs, a dummy function signature and + // arguments are used. + + auto dummy_callee_sinfo = [&]() -> FuncStructInfo { + Array dummy_params(callee_params.begin(), + callee_params.begin() + num_input_arguments); + + for (size_t i = callee_params.size() - num_trailing_int_arguments; i < callee_params.size(); + i++) { + dummy_params.push_back(callee_params[i]); + } + + Array dummy_ret(callee_params.begin() + num_input_arguments, + callee_params.end() - num_trailing_int_arguments); + + if (opt_inplace_indices) { + // For R.call_tir_inplace, the `inplace_indices` are used to + // indicate which elements of the `out_sinfo` will be generated + // as in-place mutation from an input. For any in-place + // mutation, the parameter's StructInfo must be inserted into + // `out_sinfo`. + auto inplace_indices = opt_inplace_indices.value(); + for (size_t i = 0; i < inplace_indices.size(); i++) { + auto inplace_input_index = inplace_indices[i]->value; + if (inplace_input_index >= 0) { + dummy_ret.insert(dummy_ret.begin() + i, callee_params[inplace_input_index]); + } + } + } + + auto dummy_out_sinfo = [&]() -> StructInfo { + if (dummy_ret.size() == 1) { + return dummy_ret[0]; + } else { + return TupleStructInfo(dummy_ret); + } + }(); + + return FuncStructInfo(dummy_params, dummy_out_sinfo); + }(); + + auto dummy_args = [&]() -> Array { + Array dummy_args = args->fields.Map( + [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg", sinfo); }); + + for (size_t i = 0; i < num_trailing_int_arguments; i++) { + ICHECK(packed_tuple_sinfo); + PrimStructInfo dummy_arg_sinfo = [&]() { + if (packed_tuple_sinfo->values) { + return PrimStructInfo(packed_tuple_sinfo->values.value()[i]); + } else { + return PrimStructInfo(DataType::Int(64)); + } + }(); + dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_sinfo)); + } + + return dummy_args; + }(); + + auto derived_ret_sinfo = DeriveCallRetStructInfo( + dummy_callee_sinfo, Call(Var("dummy_callee", dummy_callee_sinfo), dummy_args), + BlockBuilder::Create(NullOpt)); + + return derived_ret_sinfo; +} + StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exactly 1 output struct info."); } CHECK(call->args[0]->IsInstance()) - << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " - << "However, gets " << call->args[0]; - return call->sinfo_args[0]; + << "R.call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " + << "However, the argument " << call->args[0] << " instead has type " + << call->args[0]->GetTypeKey(); + + StructInfo explicit_sinfo = call->sinfo_args[0]; + + return explicit_sinfo; } Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { @@ -264,23 +445,37 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { << "or three arguments [callee, arg_tuple, tir_args], " << "but " << call << " has " << call->args.size() << " arguments."; - Expr arg_expr = call->args[1]; + auto callee = call->args[0]; + CHECK(callee->struct_info_.as()) + << "Operation " << call->op << " expects the first argument to be a TIR callee. " + << "However, the first argument " << callee << " has struct info " << callee->struct_info_; - CHECK(arg_expr->struct_info_.as()) - << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " - << "However, the second argument " << arg_expr << " has struct info " - << arg_expr->struct_info_ << "."; + Expr arg_tuple = call->args[1]; - if (arg_expr.as()) { - return std::move(call); - } + CHECK(arg_tuple->struct_info_.as()) + << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " + << "However, the second argument " << arg_tuple << " has struct info " + << arg_tuple->struct_info_ << "."; - CHECK(arg_expr.as()) + CHECK(arg_tuple.as() || arg_tuple.as()) << "Operation " << call->op << " must hold its arguments as an in-line tuple. " - << "However, " << call << " has arguments " << arg_expr + << "However, " << call << " has arguments " << arg_tuple << ", which is neither an in-line tuple, " << "nor a variable binding that may be normalized to an in-line tuple."; + if (call->args.size() > 2) { + Expr packed_ints = call->args[2]; + CHECK(packed_ints->struct_info_.as()) + << "Operation " << call->op << " expects the optional third argument, " + << "if present, to be a ShapeTuple. " + << "However, the third argument " << packed_ints << " has struct info " + << packed_ints->struct_info_; + } + + CHECK_EQ(call->sinfo_args.size(), 1) + << "R.call_tir should have exactly one `sinfo_args` parameter, " + << "which defines the output of the PrimFunc."; + auto unwrap_binding = [&ctx](Expr expr) -> Optional { if (auto var = expr.as()) { if (auto bound_value = ctx->LookupBinding(var.value())) { @@ -290,14 +485,21 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { return NullOpt; }; - while (auto unwrapped = unwrap_binding(arg_expr)) { - arg_expr = unwrapped.value(); - } + Tuple new_arg_tuple = [&]() { + // No replacement required. The argument tuple is already + // provided as an in-line tuple. + if (auto opt = arg_tuple.as()) { + return opt.value(); + } + + Expr unwrapped_tuple = arg_tuple; + while (auto unwrapped = unwrap_binding(unwrapped_tuple)) { + unwrapped_tuple = unwrapped.value(); + } - Tuple new_arg_expr = [&]() { // Preferred replacement. The argument tuple is provided as a // variable, but we know the value bound to that variable. - if (auto opt = arg_expr.as()) { + if (auto opt = unwrapped_tuple.as()) { return opt.value(); } @@ -306,20 +508,60 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { // example, if a relax function accepted a tuple as an parameter, // then provided that same tuple as an argument to call_tir. Array tuple_elements; - size_t num_fields = Downcast(arg_expr->struct_info_)->fields.size(); + size_t num_fields = Downcast(arg_tuple->struct_info_)->fields.size(); for (size_t i = 0; i < num_fields; i++) { - tuple_elements.push_back(TupleGetItem(arg_expr, i)); + tuple_elements.push_back(TupleGetItem(arg_tuple, i)); } return Tuple(tuple_elements); }(); - auto new_args = call->args; - new_args.Set(1, new_arg_expr); - call.CopyOnWrite()->args = new_args; + if (!new_arg_tuple.same_as(arg_tuple)) { + auto new_args = call->args; + new_args.Set(1, new_arg_tuple); + call.CopyOnWrite()->args = new_args; + } return std::move(call); } +void ValidateCallTIR(Call call) { + // This function is used for validation of `relax.call_tir`, + // along with the variants `relax.call_tir_with_grad` and + // `relax.call_tir_inplace`. Therefore, all error messages should + // be written in terms of `call->op`, and should not explicitly + // reference the `relax.call_tir` operator.` + + auto callee = call->args[0]; + Expr arg_tuple = call->args[1]; + + auto packed_int_sinfo = [&]() -> Optional { + if (call->args.size() <= 2) { + return NullOpt; + } else { + return GetStructInfo(call->args[2]); + } + }(); + + auto opt_inplace_indices = [&]() -> Optional> { + if (const auto* attrs = call->attrs.as()) { + return attrs->inplace_indices; + } else { + return NullOpt; + } + }(); + + StructInfo explicit_sinfo = call->sinfo_args[0]; + auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( + GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); + if (inferred_sinfo.defined()) { + CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) + << "TypeError: " + << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. " + << "However, the PrimFunc's signature implies that the output should be " << inferred_sinfo + << ", but the `out_sinfo` argument was " << explicit_sinfo; + } +} + RELAY_REGISTER_OP("relax.call_tir") .set_num_inputs(3) .add_argument("func", "Expr", "The destination-passing-style function.") @@ -329,6 +571,7 @@ RELAY_REGISTER_OP("relax.call_tir") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) + .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, @@ -374,6 +617,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) + .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinfo_list, @@ -514,6 +758,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIRInPlace) + .set_attr("FValidate", ValidateCallTIR) // Warning: considered pure, but it has the potential to create visible effects! // This should only be used if it has been *checked* that it is safe (no aliases, in-place // arguments will no longer be live) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index b203b322ab96..612e1459c826 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -1088,8 +1088,7 @@ class TIRFuseMutator : public ExprMutator { const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, old_gvar); GlobalVar new_gvar(old_gvar->name_hint); - UpdateStructInfo(new_gvar, - FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type))); + UpdateStructInfo(new_gvar, GetStructInfo(prim_func)); mod->Remove(old_gvar); updates->Add(new_gvar, prim_func); diff --git a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py index e1f45d278d6c..865051b0b4b9 100644 --- a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py +++ b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py @@ -512,13 +512,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv18: R.Tensor((256, 32, 128), dtype="float16") = R.reshape( lv17, R.shape([256, 32, 128]) @@ -712,13 +710,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv17 = R.dist.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv18: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape( lv17, R.shape([256, 32, 128]) @@ -1278,13 +1274,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv18 = R.call_tir( cls.reshape1, (lv17,), out_sinfo=R.Tensor((256, 32, 128), dtype="float16") @@ -1449,13 +1443,11 @@ def foo( LlamaAttentionLayerTIR.get_global_var("rotary_embedding"), (lv9, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv17 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("rotary_embedding"), (lv12, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv18 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape1"), diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 7deddfd28eb9..c0b962c3f3a0 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import pytest + import tvm import tvm.testing + from tvm import relax as rx from tvm import tir -from tvm.script import relax as R -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.script import ir as I +from tvm.script import ir as I, relax as R, tir as T m = tir.Var("m", "int64") n = tir.Var("n", "int64") @@ -702,5 +702,511 @@ def is_bfloat16_dtype(tensor: T.handle) -> T.bool: assert rx.analysis.well_formed(Module) +def test_call_tir_with_matching_arguments(): + """R.call_tir is well-formed when called with matching arguments""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_input_ndim(): + """Arguments to R.call_tir must have the correct dimensionality + + Here, the `add_one` function expects a 1-d input tensor, but is + called with a 2-d tensor. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([4, 4], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_ndim(): + """Output shape R.call_tir must have the correct dimensionality + + Here, the `add_one` function requires a 1-d output tensor, but is + provided with a 2-d tensor. + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([4, 4], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_input_shape(): + """Arguments to R.call_tir must have the correct shape + + Here, the `add_one` function expects an input tensor with 16 + elements, but is called with an input tensor with 32 elements. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([32], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_shape(): + """Output shape R.call_tir must have the correct shape + + Here, the `add_one` function requires an output tensor with 16 + elements, but is provided an output tensor with 32 elements. + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([32], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_input_dtype(): + """Arguments to R.call_tir must have the correct dtype + + Here, the `add_one` function expects an input tensor containing + float16 value, but is called with an input tensor containing + float32 values. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float32")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_dtype(): + """Output shape R.call_tir must have the correct shape + + Here, the `add_one` function requires an output tensor that may be + populated with float16 values, but is provided an output tensor + that may be populated with float32 elements. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float32")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_with_correct_dynamic_output_shape(): + """Output shape R.call_tir may not be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. This is legal, + since the output shape is determined by the `out_sinfo` parameter. + + Inability to verify the output shape does not mean that the output + shape is invalid. + + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 8], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert rx.analysis.well_formed(Module) + + +@pytest.mark.xfail(reason="Not supported") +def test_call_tir_with_incorrect_dynamic_output_shape(): + """Output shape R.call_tir may not be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. Even though the + IRModule will not provide well-defined output due to the + out-of-bounds read from buffer A, catching this error is beyond + the current scope of the Relax well-formed checker. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([16, 16], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_incorrect_dimensionality_of_output_shape(): + """Dimensionality may be verified + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. + + Even though the output shape may not be inferred from the input + arguments, the output dimensionality can still be inferred from + the PrimFunc signature. The IRModule below is ill-formed, because + the PrimFunc requires a 2-d output argument, but is provided with + a 3-d output argument. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 4, 2], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert not rx.analysis.well_formed(Module) + + +@pytest.mark.xfail(reason="Not yet supported") +def test_call_tir_output_shape_with_mixed_static_and_dynamic(): + """Some dimensions of the R.call_tir output shape may be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. This is legal, + since the output shape is taken from the `out_sinfo` parameter. + + Identifying this failure mode is not yet supported in the current + implementation. This is because the output is inferred as + `R.Tensor(ndim=3, dtype="float16")`, and the explicit `out_sinfo` + is a 3-d tensor. The mismatch in the first dimension is not yet + counted, because the entire tensor shape is removed by + `EraseToWellDefined`. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([256], "float16")): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([8, 16, 2], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(256, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [16, M, N], dtype="float16") + + for i, j, k in T.grid(16, M, N): + with T.block("compute"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi * N * M + vj * N + vk] + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_with_correct_inferred_dynamic_output_shape(): + """Some dynamic output shapes of R.call_tir may be inferred + + Here, the `flatten` function is dynamic, and will flatten any 2-d + TIR buffer. Even though it is dynamic, the input shapes are + sufficient to infer that `M==8` and `N==4`. As a result, the + output shape of `[M*N]` can be inferred to be `[32]`, and the + shape specified in `out_sinfo` can be validated. + + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([8, 4], "float16")): + B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([32], "float16")) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_with_incorrect_inferred_dynamic_output_shape(): + """Some dynamic output shapes of R.call_tir may be inferred + + Here, the `flatten` function is dynamic, and will flatten any 2-d + TIR buffer. Even though it is dynamic, the input shapes are + sufficient to infer that `M==8` and `N==4`. As a result, the + output shape of `[M*N]` can be inferred to be `[32]`, and the + shape specified in `out_sinfo` can be validated. + + This unit test is identical to the above test + `test_call_tir_with_correct_inferred_dynamic_output_shape`, except + that the output shape is explicitly specified as `[64]`, which is + caught as a mismatch from the expected output shape. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([8, 4], "float16")): + B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([64], "float16")) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_with_dtensor_arguments(): + """R.call_tir and R.dist.call_tir share the same operation + + Both `R.call_tir` and `R.dist.call_tir` produce the same + "relax.call_tir" operation, differing only in the StructInfo of + their arguments. Normalization of "relax.call_tir" must handle + `R.DTensor` arguments. + + """ + + # from tvm.script.parser import relax as R + + @I.ir_module + class Module: + I.module_attrs({"device_num": 4}) + I.module_global_infos({"mesh": [R.dist.device_mesh([4], I.Range(0, 4))]}) + + @R.function + def main(A: R.dist.DTensor([8, 4], "float16", "mesh[0]", "S[0]")): + B = R.dist.call_tir( + Module.flatten, A, out_sinfo=R.dist.DTensor([64], "float16", "mesh[0]", "S[0]") + ) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_inplace_with_correct_shapes(): + """R.call_tir_inplace is well-formed when called with matching arguments""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir_inplace( + Module.add_one, + A, + inplace_indices=[0], + out_sinfo=R.Tensor([16], "float16"), + ) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + A[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_inplace_with_incorrect_shapes(): + """R.call_tir_inplace is ill-formed when output shape does not match input""" + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir_inplace( + Module.add_one, + A, + inplace_indices=[0], + out_sinfo=R.Tensor([32], "float16"), + ) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + A[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_inplace_with_some_allocated_outputs(): + """R.call_tir_inplace may contain some non-inplace outputs""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16"), B: R.Tensor([32], "float16")): + out = R.call_tir_inplace( + Module.add_one, + (A, B), + inplace_indices=[-1, 1], + out_sinfo=[ + R.Tensor([16], "float16"), + R.Tensor([32], "float16"), + ], + ) + return out + + @T.prim_func + def add_one( + A: T.Buffer(16, "float16"), + B: T.Buffer(32, "float16"), + C: T.Buffer(16, "float16"), + ): + for i in range(32): + with T.block("inplace_B"): + vi = T.axis.remap("S", [i]) + B[vi] = B[vi] + T.float16(1.0) + + for i in range(16): + with T.block("output_C"): + vi = T.axis.remap("S", [i]) + C[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 64d5c7381171..6005ecb0fa58 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -43,6 +43,7 @@ def normalize(func: rx.Function) -> rx.Function: """ Normalize the expr to fill in the checked_type_ and struct_info fields everywhere """ + # using a default mutator to use the BlockBuilder's normalizer, # which oddly differs from the Normalize pass @rx.expr_functor.mutator @@ -435,9 +436,13 @@ def test_call_tir(): @tvm.script.ir_module class TestCallTIR: @T.prim_func - def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + def addone(A_handle: T.handle, B_handle: T.handle) -> None: + m = T.int64() + n = T.int64() + A = T.match_buffer(A_handle, (m, n), "float32") + B = T.match_buffer(B_handle, (m, n), "float32") T.func_attr(({"global_symbol": "addone"})) - for i, j in T.grid(16, 16): + for i, j in T.grid(m, n): with T.block("addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.int32(1) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 8d5eb07c7858..cd6e285de499 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -172,8 +172,8 @@ def tir_id(x: T.handle, y: T.handle) -> None: T.func_attr({"global_symbol": "tir_id"}) m = T.int32() n = T.int32() - A = T.match_buffer(x, (m, n)) - B = T.match_buffer(y, (m, n)) + A = T.match_buffer(x, (m, n), "int32") + B = T.match_buffer(y, (m, n), "int32") for i, j in T.grid(m, n): with T.block("id"): @@ -185,9 +185,9 @@ def tir_id2(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_id"}) m = T.int32() n = T.int32() - A = T.match_buffer(x, (m, n)) - B = T.match_buffer(y, (m, n)) - C = T.match_buffer(z, (m, n)) + A = T.match_buffer(x, (m, n), "int32") + B = T.match_buffer(y, (m, n), "int32") + C = T.match_buffer(z, (m, n), "int32") for i, j in T.grid(m, n): with T.block("id"): diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 03a3beb2f27e..7a3b65cea10e 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -72,7 +72,7 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) lv2 = R.call_tir( - cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32]) + cls.tir_zeros, [], R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32]) ) gv = (lv1, lv2) R.output(gv) diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index d83f83f4e188..21e1d82d28b5 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -114,9 +114,10 @@ def main( with db: opt_model = torch.compile(model, backend=relax_dynamo()) inp = torch.randn(10, 100) - tvm.testing.assert_allclose( - opt_model(inp).detach().numpy(), model(inp).detach().numpy(), rtol=1e-5, atol=1e-5 - ) + + default_output = model(inp).detach().numpy() + optimized_output = opt_model(inp).detach().numpy() + tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5) def test_relax_dynamo_dynamic(): diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 40624790cb5a..6a337b34c114 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -570,10 +570,18 @@ def test_tensor_ir_op(): @T.prim_func(private=True) def fused_rope( # pylint: disable=too-many-locals var_qkv: T.handle, - offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle, + # Scalar arguments must be specified after tensor arguments, + # including the output tensor arguments + # + # TODO(Lunderberg): Update + # `tvm.relax.frontend.nn.op.tensor_ir_op` to use `PrimValue` + # instead of `tir_vars`, so that the order can be consistent + # between the function definition and the arguments in + # `op.tensor_ir_op`. + offset: T.int64, ): batch_size = T.int64() seq_len = T.int64() @@ -601,7 +609,7 @@ def test(self, qkv: Tensor, offset: tir.Var): @I.ir_module class Expected: @T.prim_func(private=True) - def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle): + def llama_fused_rope(var_qkv: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, offset: T.int64): batch_size, seq_len = T.int64(), T.int64() qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), "float16") q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16") @@ -669,10 +677,11 @@ class Model(Module): def test( self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int ): - tensor_expr_op_out = op.tensor_ir_op( + tensor_expr_op_out = op.tensor_ir_inplace_op( inplace_take, "inplace_take", args=[embedding_table, input_ids, embedding_dst, offset], + inplace_indices=[2], out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype), ) return tensor_expr_op_out @@ -719,10 +728,11 @@ def test( R.func_attr({"num_input": 4}) cls = Expected with R.dataflow(): - lv1 = R.call_tir( + lv1 = R.call_tir_inplace( cls.inplace_take, (embedding_table, input_ids, embedding_dst), out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype), + inplace_indices=[2], tir_vars=R.shape([offset_1]), ) gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1 diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index ee2df866fb35..e3274aea886a 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -86,7 +86,11 @@ def test_call_tir_rewrite(): @tvm.script.ir_module class TestCallTIRRewrite: @T.prim_func - def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + def exp(A_handle: T.handle, B_handle: T.handle): + m = T.int64() + n = T.int64() + A = T.match_buffer(A_handle, (m, n), "float32") + B = T.match_buffer(B_handle, (m, n), "float32") T.evaluate(0) @R.function diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 65970d64550e..0ddf985ec4ba 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -277,18 +277,26 @@ def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"): def test_unused_relax_func_symbolic_shape(): # Test with relax function w/ symbolic shape. - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func - def tir_add( - x: T.Buffer((16, 16), "float32"), - y: T.Buffer((16, 16), "float32"), - z: T.Buffer((16, 16), "float32"), + def tir_matmul( + x_handle: T.handle, + y_handle: T.handle, + z_handle: T.handle, ) -> None: - for i, j in T.grid(16, 16): - with T.block("add"): - vi, vj = T.axis.remap("SS", [i, j]) - z[vi, vj] = x[vi, vj] + y[vi, vj] + m = T.int64() + n = T.int64() + k = T.int64() + x = T.match_buffer(x_handle, (m, n), "float32") + y = T.match_buffer(y_handle, (n, k), "float32") + z = T.match_buffer(z_handle, (m, k), "float32") + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + z[vi, vj] = 0.0 + z[vi, vj] = z[vi, vj] + x[vi, vk] * y[vk, vj] @R.function(private=True) def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): @@ -298,7 +306,7 @@ def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "flo @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): m, k = T.int64(), T.int64() - gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((m + 1, k), dtype="float32")) + gv0 = R.call_tir(InputModule.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 mod = InputModule @@ -306,7 +314,7 @@ def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) new_mod = DeadCodeElimination()(mod) assert check_if_func_exists(new_mod, "main") - assert check_if_func_exists(new_mod, "tir_add") + assert check_if_func_exists(new_mod, "tir_matmul") assert not check_if_func_exists(new_mod, "unused_func") diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 17bf58613294..9ad66bec012a 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -875,7 +875,7 @@ class Module: def main(x: R.Tensor((1, 512, 64, 64), "float32"), mean: R.Tensor((64, 64), "float32"), var: R.Tensor((64, 64), "float32")): cls = Module with R.dataflow(): - gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64), 'float32')) gv1 = R.call_tir(cls.relu, gv0, out_sinfo=R.Tensor((1, 512, 64, 64), "float32")) R.output(gv1) return gv1 @@ -955,7 +955,7 @@ def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.func_attr({"Primitive": 1}) cls = Expected with R.dataflow(): - gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64), 'float32')) gv = R.call_tir(cls.relu, (gv0,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32")) R.output(gv) return gv @@ -1452,7 +1452,7 @@ def main( R.Tensor((2,), "float32"), R.Tensor((2,), "float32"), R.Tensor((2,), "float32"), - ) + ), ): with R.dataflow(): x0 = x[0] @@ -1486,7 +1486,7 @@ def main( R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), - ) + ), ) -> R.Tensor((2,), dtype="float32"): cls = Expected with R.dataflow(): diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 1582526042f1..a07875fcdae6 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -696,10 +696,10 @@ def test_ignore_call_tir(): class Conv2dReLUCallTIR: @T.prim_func def relu( - data: T.Buffer((64, 64, 56, 56), "float32"), - out: T.Buffer((64, 64, 56, 56), "float32"), + data: T.Buffer((1, 64, 56, 56), "float32"), + out: T.Buffer((1, 64, 56, 56), "float32"), ): - for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56): with T.block("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) out[i, j, k, l] = T.max(data[i, j, k, l], 0.0) @@ -714,7 +714,7 @@ def main( relu1 = R.call_tir( Conv2dReLUCallTIR.relu, (conv1,), - R.Tensor((64, 64, 56, 56), "float32"), + R.Tensor((1, 64, 56, 56), "float32"), ) R.output(relu1) @@ -724,11 +724,11 @@ def main( class Conv2dReLUCallTIR_partitioned: @T.prim_func def relu( - data: T.Buffer((64, 64, 56, 56), "float32"), - out: T.Buffer((64, 64, 56, 56), "float32"), + data: T.Buffer((1, 64, 56, 56), "float32"), + out: T.Buffer((1, 64, 56, 56), "float32"), ): # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56): with T.block("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(data[i, j, k, l]) @@ -754,7 +754,7 @@ def fused_relax_nn_conv2d( def main( data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), - ) -> R.Tensor((64, 64, 56, 56), dtype="float32"): + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): cls = Conv2dReLUCallTIR_partitioned with R.dataflow(): lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( @@ -763,7 +763,7 @@ def main( relu1 = R.call_tir( cls.relu, (lv,), - out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32"), + out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"), ) R.output(relu1) return relu1 diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 278ac825f7a7..87a5698f1bf8 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -43,7 +43,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -124,7 +124,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -209,7 +209,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -298,7 +298,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -441,8 +441,8 @@ def main_transform_params( @T.prim_func(private=True) def slice_buffer( Input: T.Buffer((16, 16), "float32"), - slice_index: T.int64, Output: T.Buffer(16, "float32"), + slice_index: T.int64, ): for i in T.grid(16): with T.block("slice_buffer"): @@ -479,8 +479,8 @@ def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])): @T.prim_func(private=True) def slice_buffer( Input: T.Buffer((16, 16), "float32"), - slice_index: T.int64, Output: T.Buffer(16, "float32"), + slice_index: T.int64, ): for i in T.grid(16): with T.block("slice_buffer"): @@ -511,7 +511,7 @@ def main_transform_params( params: R.Tuple( R.Tensor((3, "ic", 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32"), - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32") ): @@ -637,7 +637,7 @@ def transform_params( params: R.Tuple( R.Tensor((3, "ic", 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32"), - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32") ): @@ -691,7 +691,7 @@ def test_duplicate_outputs(): class Before: @R.function def main_transform_params( - params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], dtype="int32")) + params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], dtype="int32")), ): R.func_attr({"relax.force_pure": True}) param0 = params[0] @@ -966,7 +966,7 @@ def transform_params( class Expected: @R.function def transform_params( - fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object) + fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object), ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, dtype="float32")): R.func_attr({"num_input": 1}) m = T.int64() diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index f7befd3b886a..5a7d76d8fe41 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -252,11 +252,15 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): ] @R.function - def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + def main( + x: R.Tensor((8, 16, 128), dtype="float16") + ) -> R.Tensor((1, 8, 16, 128), dtype="float16"): cls = Module with R.dataflow(): - y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) - z = R.add(y, R.const(1, "float32")) + y = R.call_tir( + cls.reshape, (x,), out_sinfo=R.Tensor((1, 8, 16, 128), dtype="float16") + ) + z = R.add(y, R.const(1, "float16")) R.output(z) return z @@ -290,10 +294,14 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): ] @R.function - def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + def main( + x: R.Tensor((8, 16, 128), dtype="float16") + ) -> R.Tensor((1, 8, 16, 128), dtype="float16"): with R.dataflow(): - y: R.Tensor((2, 4, 3), dtype="float32") = R.reshape(x, R.shape([2, 4, 3])) - z: R.Tensor((2, 4, 3), dtype="float32") = R.add(y, R.const(1, "float32")) + y: R.Tensor((1, 8, 16, 128), dtype="float16") = R.reshape( + x, R.shape([1, 8, 16, 128]) + ) + z: R.Tensor((1, 8, 16, 128), dtype="float16") = R.add(y, R.const(1, "float16")) R.output(z) return z @@ -383,7 +391,7 @@ def main( R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), - ) + ), ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"): cls = Module with R.dataflow(): @@ -444,7 +452,7 @@ def main( R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), - ) + ), ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"): with R.dataflow(): lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0] @@ -735,7 +743,6 @@ def add( z_handle: T.handle, N: T.int64, ): - y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32") y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32") z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32") diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index ea99d49270a1..64f2efd4af9e 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -77,7 +77,7 @@ def test_mismatch_cast_dims_and_ndim(): @R.function def f( - x: R.Tensor((2, 3), "float32", ndim=3) + x: R.Tensor((2, 3), "float32", ndim=3), ): # error: ndim and the shape dims are mismatch return x @@ -961,11 +961,11 @@ def test_call_tir_with_tir_var(): class Module: @R.function def main( - dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", "float32")) + dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",), "float32") ) -> R.Tensor(("n * 2",), "float32"): n = T.int64() cls = Module - y = R.call_tir(cls.copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,)) + y = R.call_tir(cls.copy, x, R.Tensor((n * 2,), dtype="float32"), tir_vars=(n,)) return y @T.prim_func @@ -2171,7 +2171,9 @@ def func(z: R.Tensor((4, 4), "float32")): @R.function(private=True) def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]): alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor( - R.shape([4, 4]), R.dtype("float32"), R.prim_value(2) # Make sure prim_value is 2 + R.shape([4, 4]), + R.dtype("float32"), + R.prim_value(2), # Make sure prim_value is 2 ) shape: R.Shape([4, 4]) = R.shape_of(alloc) shape_1: R.Shape([4, 4]) = shape @@ -2203,7 +2205,9 @@ def func(z: R.Tensor((4, 4), "float32")): @R.function(private=True) def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]): alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor( - R.shape([4, 4]), R.dtype("float32"), R.prim_value(1) # Make sure prim_value is 1 + R.shape([4, 4]), + R.dtype("float32"), + R.prim_value(1), # Make sure prim_value is 1 ) shape: R.Shape([4, 4]) = R.shape_of(alloc) shape_1: R.Shape([4, 4]) = shape @@ -2372,7 +2376,6 @@ def explicit_sinfo( B: R.Tensor(["N"], "float32"), cond: R.Prim("bool"), ) -> R.Tensor(["N"], "float32"): - N = T.int64() if cond: diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 30fd06d4f14d..ecf33aa9da1e 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -988,8 +988,10 @@ class ModA: I.module_attrs({"system_lib_prefix": "libA_"}) @T.prim_func - def tir_init(x: T.Buffer((2), "float32")) -> None: - for i in range(2): + def tir_init(x_handle: T.handle): + N = T.int64() + x = T.match_buffer(x_handle, [N], "float32") + for i in range(N): x[i] = T.float32(0) @R.function @@ -1003,8 +1005,10 @@ class ModB: I.module_attrs({"system_lib_prefix": "libB_"}) @T.prim_func - def tir_init(x: T.Buffer((2), "float32")) -> None: - for i in range(2): + def tir_init(x_handle: T.handle): + N = T.int64() + x = T.match_buffer(x_handle, [N], "float32") + for i in range(N): x[i] = T.float32(1) @R.function From 4eafd00cada11a03c2a949cc6fd0e5d9a06e013b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 6 Sep 2024 09:46:00 -0500 Subject: [PATCH 117/202] [Relax][Bugfix] FCallPacked not checked in CodegenVMTIR (#17073) Prior to this commit, an operator's `FCallPacked` attribute, used to specify a 1:1 mapping between a relax operator and a `PackedFunc` that implements it, was only checked in `CodegenVM`. Any operator with `FCallPacked` would raise an error when compiled using `CodegenVMTIR`. This commit removes the `FCallPacked` handling from `CodegenVM` altogether, and instead checks for this attribute as part of `LegalizeOps`. This provides the same functionality across both backends. --- src/relax/backend/vm/codegen_vm.cc | 24 +--- src/relax/backend/vm/codegen_vm_tir.cc | 24 +--- src/relax/transform/legalize_ops.cc | 25 ++-- tests/python/relax/test_relax_operators.py | 139 ++++++++++++--------- 4 files changed, 101 insertions(+), 111 deletions(-) diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 1c795594629e..ca2d4d4fdb2e 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -45,21 +45,6 @@ using namespace relax; using namespace tvm::runtime; using namespace tvm::runtime::relax_vm; -namespace { -// Helper function to get the function name of the registered packed function implementation of -// relax operator. -FCallPacked GetPackedFuncName(const Call& call) { - static auto op_map = Op::GetAttrMap("FCallPacked"); - if (call->op.as()) { - Op op = Downcast(call->op); - if (op_map.count(op)) { - return op_map[op]; - } - } - return {}; -} -} // namespace - /*! * \brief A class to generate VM executable for Relax functions. */ @@ -156,14 +141,7 @@ class CodeGenVM : public ExprFunctor { // allocate dst register. RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister(); if (call->op.as()) { - // special case generate for the intrinsics whose attribute fields - // cannot be represented by args in the CallNode - FCallPacked name = GetPackedFuncName(call); - if (!name.empty()) { - // If the operator has a registered packed function implementation, emit call to that packed - // function. - EmitPackedFuncCall(call, name, dst_reg); - } else if (call_node->op == call_builtin_with_ctx_op_) { + if (call_node->op == call_builtin_with_ctx_op_) { // TODO(relax-team) migrate most handling of op to // directly map to call_builtin_with_ctx before codegen and simplify vm codegen. EmitCallBuiltinWithCtx(call, dst_reg); diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index 5e6a1c3f8442..a92cf7c749a0 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -44,21 +44,6 @@ namespace relax_vm { using vm::VMFuncInfo; -namespace { -// Helper function to get the function name of the registered packed function implementation of -// relax operator. -FCallPacked GetPackedFuncName(const Call& call) { - static auto op_map = Op::GetAttrMap("FCallPacked"); - if (call->op.as()) { - Op op = Downcast(call->op); - if (op_map.count(op)) { - return op_map[op]; - } - } - return {}; -} -} // namespace - /*! * \brief A class to generate VMTIR for Relax functions. * @@ -247,14 +232,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); if (call->op.as()) { - // special case generate for the intrinsics whose attribute fields - // cannot be represented by args in the CallNode - FCallPacked name = GetPackedFuncName(call); - if (name.size()) { - // If the operator has a registered packed function implementation, emit call to that packed - // function. - EmitCallPacked(name, VisitArray(call->args), dst_reg); - } else if (call_node->op == call_builtin_with_ctx_op_) { + if (call_node->op == call_builtin_with_ctx_op_) { EmitCallBuiltinWithCtx(call, dst_reg); } else if (call_node->op == alloc_storage_op_) { EmitAllocStorage(call, dst_reg); diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 34902fa0f8b6..4a6b44bf2839 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -224,6 +224,7 @@ class LegalizeMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { Call visited_call = Downcast(this->VisitExprPostOrder_(call)); static const auto& legalize_map = Op::GetAttrMap("FLegalize"); + static const auto& call_packed_map = Op::GetAttrMap("FCallPacked"); static const auto& requires_arg_shapes_map = Op::GetAttrMap("RequiresArgumentShapes"); static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); static const Op& call_tir_op = Op::Get("relax.call_tir"); @@ -236,7 +237,7 @@ class LegalizeMutator : public ExprMutator { } auto op = GetRef(op_node); - bool can_legalize = [&]() -> bool { + bool shapes_are_known_if_required = [&]() -> bool { bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value; if (!requires_arg_shapes) { // This operator does not require its arguments to have a @@ -299,23 +300,31 @@ class LegalizeMutator : public ExprMutator { return true; }(); - if (!can_legalize) { - return visited_call; - } - FLegalize legalization_func; - if (auto opt_custom_legalize = cmap_.Get(op->name)) { + if (auto opt_custom_legalize = cmap_.Get(op->name); + opt_custom_legalize && shapes_are_known_if_required) { // First choice, use a custom legalization function legalization_func = opt_custom_legalize.value(); - } else if (legalize_map.count(op)) { + } else if (legalize_map.count(op) && shapes_are_known_if_required) { // Second choice, use a default legalization legalization_func = legalize_map[op]; + } else if (call_packed_map.count(op)) { + // Third choice, use an explicit FCallPacked replacement. This does not require the shape + String packed_func_name = call_packed_map[op]; + legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr { + return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetStructInfo(call)}); + }; } else { // No legalization. if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op && op != call_pure_packed_op) { - LOG(WARNING) << "No legalization func for " << op->name << " is found."; + if (shapes_are_known_if_required) { + LOG(WARNING) << "No legalization func for " << op->name << " is found."; + } else { + LOG(WARNING) << "Cannot legalize " << visited_call + << ", missing known shapes for arguments and return value"; + } } return visited_call; } diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 41618a32cb55..fcb8727d8508 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -27,6 +27,8 @@ from tvm._ffi.base import TVMError from tvm.script import ir as I, relax as R, tir as T +exec_mode = tvm.testing.parameter("bytecode", "compiled") + @tvm.script.ir_module class InputModule: @@ -37,7 +39,7 @@ def foo(x: R.Tensor(("m", "n"), "int64")): return y, y_sorted -def run_cpu(mod, func_name, *args): +def run_cpu(mod, func_name, *args, exec_mode): if isinstance(mod, relax.Function): func = mod args = [func_name, *args] @@ -45,17 +47,17 @@ def run_cpu(mod, func_name, *args): mod = tvm.IRModule.from_expr(func) target = tvm.target.Target("llvm") - ex = relax.build(mod, target) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) return vm[func_name](*args) -def test_unique(): +def test_unique(exec_mode): # TODO(prakalp): also add test for compiling and running on cuda device. data_numpy = np.random.randint(0, 16, (16, 16)) data = tvm.nd.array(data_numpy) - result, result_sorted = run_cpu(InputModule, "foo", data) + result, result_sorted = run_cpu(InputModule, "foo", data, exec_mode=exec_mode) expected_output_sorted, indices = np.unique(data_numpy, return_index=True) expected_output = [data_numpy.flatten()[index] for index in sorted(indices, reverse=True)] @@ -81,12 +83,17 @@ def foo(x: R.Tensor((), "int32")): return x -def test_print(): +def test_print(exec_mode): try: stdout = sys.stdout with tempfile.TemporaryFile(mode="w+") as test_out: sys.stdout = test_out - run_cpu(PrintTest, "foo", tvm.nd.array(np.array(1).astype("int32"))) + run_cpu( + PrintTest, + "foo", + tvm.nd.array(np.array(1).astype("int32")), + exec_mode=exec_mode, + ) test_out.seek(0) printed_text = str(test_out.read()) expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1 1\nAnother print: 1 (1, 1)\n" @@ -95,65 +102,65 @@ def test_print(): sys.stdout = stdout -def test_assert_passes(): +def test_assert_passes(exec_mode): @R.function(pure=False) def func(x: R.Tensor((), "int32")): _ = R.assert_op(relax.const(True)) return x - run_cpu(func, tvm.nd.array(np.array(1).astype("int32"))) + run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode) -def test_assert_passes_with_format_args(): +def test_assert_passes_with_format_args(exec_mode): @R.function(pure=False) def func(x: R.Tensor((), "int32")): _ = R.assert_op(relax.const(True), x, format="You won't see me") return x - run_cpu(func, tvm.nd.array(np.array(1).astype("int32"))) + run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode) -def test_assert_fails(): +def test_assert_fails(exec_mode): @R.function(pure=False) def func(x: R.Tensor((), "int32")): _ = R.assert_op(relax.const(False)) return x with pytest.raises(AssertionError, match="Assertion Failed"): - run_cpu(func, tvm.nd.array(np.array(1).astype("int32"))) + run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode) -def test_assert_fails_with_message(): +def test_assert_fails_with_message(exec_mode): @R.function(pure=False) def func(x: R.Tensor((), "int32")): _ = R.assert_op(relax.const(False), format="I failed...") return x with pytest.raises(AssertionError, match="I failed..."): - run_cpu(func, tvm.nd.array(np.array(1).astype("int32"))) + run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode) -def test_assert_fails_with_args(): +def test_assert_fails_with_args(exec_mode): @R.function(pure=False) def func(x: R.Tensor((), "int32")): _ = R.assert_op(relax.const(False), [x, x]) return x with pytest.raises(AssertionError, match="5, 5"): - run_cpu(func, tvm.nd.array(np.array(5).astype("int32"))) + run_cpu(func, tvm.nd.array(np.array(5).astype("int32")), exec_mode=exec_mode) -def test_assert_fails_with_formatted_args(): +def test_assert_fails_with_formatted_args(exec_mode): @R.function(pure=False) def func(x: R.Tensor((), "int32")): _ = R.assert_op(relax.const(False), x, format="Number: {}") return x with pytest.raises(AssertionError, match="Number: 6"): - run_cpu(func, tvm.nd.array(np.array(6).astype("int32"))) + run_cpu(func, tvm.nd.array(np.array(6).astype("int32")), exec_mode=exec_mode) -def test_assert_on_argument_passes(): +def test_assert_on_argument_passes(exec_mode): @R.function(pure=False) def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")): _ = R.assert_op(condition) @@ -161,10 +168,10 @@ def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")): condition = tvm.nd.array(np.array(True)) x = tvm.nd.array(np.array(5).astype("int32")) - run_cpu(func, condition, x) + run_cpu(func, condition, x, exec_mode=exec_mode) -def test_assert_on_argument_fails(): +def test_assert_on_argument_fails(exec_mode): @R.function(pure=False) def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")): _ = R.assert_op(condition) @@ -173,10 +180,10 @@ def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")): condition = tvm.nd.array(np.array(False)) x = tvm.nd.array(np.array(5).astype("int32")) with pytest.raises(AssertionError): - run_cpu(func, condition, x) + run_cpu(func, condition, x, exec_mode=exec_mode) -def test_assert_on_symbolic_var_passes(): +def test_assert_on_symbolic_var_passes(exec_mode): @R.function(pure=False) def func(x: R.Tensor(["N"], "int32")): N = T.int64() @@ -184,10 +191,10 @@ def func(x: R.Tensor(["N"], "int32")): return x x = tvm.nd.array(np.arange(8, dtype="int32")) - run_cpu(func, x) + run_cpu(func, x, exec_mode=exec_mode) -def test_assert_on_symbolic_var_fails(): +def test_assert_on_symbolic_var_fails(exec_mode): @R.function(pure=False) def func(x: R.Tensor(["N"], "int32")): N = T.int64() @@ -196,7 +203,7 @@ def func(x: R.Tensor(["N"], "int32")): x = tvm.nd.array(np.arange(10, dtype="int32")) with pytest.raises(AssertionError): - run_cpu(func, x) + run_cpu(func, x, exec_mode=exec_mode) @tvm.script.ir_module @@ -223,23 +230,31 @@ def get_constant_shape() -> R.Shape((2, 2)): return R.shape_of(x) -def test_op_shape_of(): - unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape") +def test_op_shape_of(exec_mode): + unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape", exec_mode=exec_mode) assert unit_shape == tvm.runtime.ShapeTuple([]) - const_shape = run_cpu(ShapeOfTest, "get_constant_shape") + const_shape = run_cpu(ShapeOfTest, "get_constant_shape", exec_mode=exec_mode) assert const_shape == tvm.runtime.ShapeTuple([2, 2]) - scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32"))) + scalar_shape = run_cpu( + ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")), exec_mode=exec_mode + ) assert scalar_shape == tvm.runtime.ShapeTuple([]) tensor_shape = run_cpu( - ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 3)).astype("int32")) + ShapeOfTest, + "get_shape", + tvm.nd.array(np.zeros((1, 2, 3)).astype("int32")), + exec_mode=exec_mode, ) assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3]) constrained_shape = run_cpu( - ShapeOfTest, "get_constrained_shape", tvm.nd.array(np.zeros((1,)).astype("int32")) + ShapeOfTest, + "get_constrained_shape", + tvm.nd.array(np.zeros((1,)).astype("int32")), + exec_mode=exec_mode, ) assert constrained_shape == tvm.runtime.ShapeTuple([1]) @@ -257,7 +272,7 @@ def symbolic_shape(shape: R.Shape(("m", "n"))) -> R.Tensor(ndim=-1): return R.shape_to_tensor(shape) -def test_op_shape_to_tensor(): +def test_op_shape_to_tensor(exec_mode): # Check struct info isinstance(ShapeToTensorTest["const_shape"].body.struct_info, tvm.relax.TensorStructInfo) assert ShapeToTensorTest["const_shape"].body.struct_info.ndim == 1 @@ -265,24 +280,32 @@ def test_op_shape_to_tensor(): assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1 # Check its functionality - out2d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2])) + out2d = run_cpu( + ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2]), exec_mode=exec_mode + ) assert isinstance(out2d, tvm.runtime.ndarray.NDArray) assert np.array_equal(out2d.numpy(), np.array([3, 2])) - out3d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2])) + out3d = run_cpu( + ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2]), exec_mode=exec_mode + ) assert isinstance(out3d, tvm.runtime.ndarray.NDArray) assert np.array_equal(out3d.numpy(), np.array([3, 3, 2])) - out4d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2, 2])) + out4d = run_cpu( + ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2, 2]), exec_mode=exec_mode + ) assert isinstance(out4d, tvm.runtime.ndarray.NDArray) assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2])) - outs = run_cpu(ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2])) + outs = run_cpu( + ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2]), exec_mode=exec_mode + ) assert isinstance(outs, tvm.runtime.ndarray.NDArray) assert np.array_equal(outs.numpy(), np.array([3, 2])) -def test_op_call_pure_packed(): +def test_op_call_pure_packed(exec_mode): @tvm.script.ir_module class CallPureTest: @R.function @@ -294,11 +317,11 @@ def pure_copy(x: R.Tensor((3, 4), "float32")): np.random.seed(0) # to avoid flakiness arr = np.random.rand(3, 4).astype("float32") - copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr)) + copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr), exec_mode=exec_mode) assert (copy_found.numpy() == arr).all() -def test_op_call_inplace_packed(): +def test_op_call_inplace_packed(exec_mode): # in this case we can use the same test as above @tvm.script.ir_module class CallInplaceTest: @@ -312,7 +335,7 @@ def pure_copy(x: R.Tensor((3, 4), "float32")): ) return z - @tvm.register_func("test.inplace.add") + @tvm.register_func("test.inplace.add", override=True) def inplace_add(a, b): arr_a = a.numpy() arr_b = b.numpy() @@ -340,11 +363,13 @@ def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): arr_b = np.random.rand(3, 4).astype("float32") sum = arr_a + arr_b tvm_arr_a = tvm.nd.array(arr_a) - result = run_cpu(CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b)) + result = run_cpu( + CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b), exec_mode=exec_mode + ) assert result == tvm_arr_a assert (result.numpy() == sum).all() - @tvm.register_func("test.inplace.tuple_add") + @tvm.register_func("test.inplace.tuple_add", override=True) def inplace_tuple_add(a, b): arr_a = a.numpy() arr_b = b.numpy() @@ -374,14 +399,14 @@ def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") sum = arr_a + arr_b tvm_arr_a = tvm.nd.array(arr_a) tvm_arr_b = tvm.nd.array(arr_b) - result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b) + result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b, exec_mode=exec_mode) assert result[0] == tvm_arr_a assert (result[0].numpy() == sum).all() assert result[1] != tvm_arr_a and result[1] != tvm_arr_b assert (result[1].numpy() == sum).all() -def test_op_to_device(): +def test_op_to_device(exec_mode): @tvm.script.ir_module class CallToDevice: @R.function @@ -397,11 +422,11 @@ def to_dev(x: R.Tensor((3, 4), "float32")): np.random.seed(0) # to avoid flakiness arr = np.random.rand(3, 4).astype("float32") - copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr)) + copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr), exec_mode=exec_mode) assert (copy_found.numpy() == arr).all() -def test_op_to_vdevice(): +def test_op_to_vdevice(exec_mode): @tvm.script.ir_module class ToVDevice: I.module_global_infos({"vdevice": [I.vdevice("llvm")]}) @@ -414,11 +439,11 @@ def to_vdev(x: R.Tensor((3, 4), "float32")): np.random.seed(0) arr = np.random.rand(3, 4).astype("float32") - copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr)) + copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr), exec_mode=exec_mode) assert (copy_found.numpy() == arr).all() -def test_scalar_tensor_as_branch_condition(): +def test_scalar_tensor_as_branch_condition(exec_mode): """The condition of a branch may be a scalar tensor""" @R.function @@ -429,14 +454,14 @@ def func(condition: R.Tensor((), "bool")): out = R.prim_value(10) return out - res = run_cpu(func, tvm.nd.array(np.array(True))) + res = run_cpu(func, tvm.nd.array(np.array(True)), exec_mode=exec_mode) assert res == 5 - res = run_cpu(func, tvm.nd.array(np.array(False))) + res = run_cpu(func, tvm.nd.array(np.array(False)), exec_mode=exec_mode) assert res == 10 -def test_prim_value_as_branch_condition(): +def test_prim_value_as_branch_condition(exec_mode): """The condition may be a PrimValue""" @R.function @@ -447,14 +472,14 @@ def func(condition: R.Prim("bool")): out = R.prim_value(10) return out - res = run_cpu(func, True) + res = run_cpu(func, True, exec_mode=exec_mode) assert res == 5 - res = run_cpu(func, False) + res = run_cpu(func, False, exec_mode=exec_mode) assert res == 10 -def test_computed_prim_value_as_branch_condition(): +def test_computed_prim_value_as_branch_condition(exec_mode): """The R.Prim condition may be computed within the function""" @R.function @@ -466,10 +491,10 @@ def func(x: R.Tensor(["N"], "int64")): out = R.prim_value(10) return out - res = run_cpu(func, tvm.nd.array(np.arange(16))) + res = run_cpu(func, tvm.nd.array(np.arange(16)), exec_mode=exec_mode) assert res == 5 - res = run_cpu(func, tvm.nd.array(np.arange(20))) + res = run_cpu(func, tvm.nd.array(np.arange(20)), exec_mode=exec_mode) assert res == 10 From ec28b6794b93b90bfdaf3b281cd7f4c3b4a1fbf8 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 6 Sep 2024 23:48:49 +0900 Subject: [PATCH 118/202] [Apps] Remove mxnet dependency from /apps/android_camera/models (#17297) * use torchvision's resnet18 instead of mxnet * cleanup import statements --- apps/android_camera/models/prepare_model.py | 31 +++++++++++---------- apps/android_camera/models/requirements.txt | 3 +- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/apps/android_camera/models/prepare_model.py b/apps/android_camera/models/prepare_model.py index 9f2cbbdd6d1f..5fd99967aea3 100644 --- a/apps/android_camera/models/prepare_model.py +++ b/apps/android_camera/models/prepare_model.py @@ -15,18 +15,16 @@ # specific language governing permissions and limitations # under the License. -import logging -import pathlib -from pathlib import Path -from typing import Union +import json import os from os import environ -import json +from pathlib import Path +from typing import Union import tvm import tvm.relay as relay -from tvm.contrib import utils, ndk, graph_executor as runtime -from tvm.contrib.download import download_testdata, download +from tvm.contrib import ndk +from tvm.contrib.download import download, download_testdata target = "llvm -mtriple=arm64-linux-android" target_host = None @@ -50,15 +48,18 @@ def del_dir(target: Union[Path, str], only_if_empty: bool = False): def get_model(model_name, batch_size=1): if model_name == "resnet18_v1": - import mxnet as mx - from mxnet import gluon - from mxnet.gluon.model_zoo import vision + import torch + import torchvision - gluon_model = vision.get_model(model_name, pretrained=True) - img_size = 224 - data_shape = (batch_size, 3, img_size, img_size) - net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) - return (net, params) + weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + torch_model = torchvision.models.resnet18(weights=weights).eval() + input_shape = [1, 3, 224, 224] + input_data = torch.randn(input_shape) + scripted_model = torch.jit.trace(torch_model, input_data) + + input_infos = [("data", input_data.shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, input_infos) + return (mod, params) elif model_name == "mobilenet_v2": import keras from keras.applications.mobilenet_v2 import MobileNetV2 diff --git a/apps/android_camera/models/requirements.txt b/apps/android_camera/models/requirements.txt index dbf496b2d968..3e35efdeb66e 100644 --- a/apps/android_camera/models/requirements.txt +++ b/apps/android_camera/models/requirements.txt @@ -1,4 +1,5 @@ keras==2.9 -mxnet scipy tensorflow==2.9.3 +torch +torchvision From ff884b609a2eb94fef1f061bff0ec867b79d4ba0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 6 Sep 2024 11:28:28 -0500 Subject: [PATCH 119/202] [Relax][Transform] Handle tuple return in RemoveUnusedOutputs (#17253) * [Relax][Transform] Handle tuple return in RemoveUnusedOutputs Prior to this commit, the `relax.transform.RemoveUnusedOutputs` pass only marked a tuple element as used if it occurred in a `TupleGetItem` node. This ignored use cases where a tuple is used as an aggregate object, such as returning a tuple from a function. This would collect incorrect results for a Relax function that calls a subroutine, receives a tuple as the return value of the subroutine, then returns that tuple. This commit updates `RemoveUnusedOutputs` to look for usage of a tuple object, not just for usage in `TupleGetItem`. Closes https://github.com/apache/tvm/issues/17247 --- src/relax/transform/remove_unused_outputs.cc | 59 ++++++++++++------- .../test_transform_remove_unused_outputs.py | 20 +++++++ 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index e3bf12382c67..9a5c31e79ba0 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -92,29 +92,48 @@ class PartialTupleUsageCollector : ExprVisitor { } void VisitExpr_(const TupleGetItemNode* op) override { - Expr tuple = UnwrapBindings(op->tuple); - - if (auto call = tuple.as()) { - if (auto opt_callee = call->op.as()) { - auto callee = opt_callee.value(); - if (auto it = output_usage_mask_.find(callee); it != output_usage_mask_.end()) { - auto& used_indices = it->second; - - CHECK_GE(op->index, 0) << "IndexError: " - << "Indices for TupleGetItem must be non-negative, " - << "but expression " << GetRef(op) - << " uses a tuple index of " << op->index; - size_t index = op->index; - - CHECK_LT(index, used_indices.size()) - << "IndexError: " - << "Indices for TupleGetItem must be less than the size of the tuple, " - << "but expression " << GetRef(op) << " uses a tuple index of " << op->index - << " for a tuple of size " << used_indices.size(); - used_indices[index] = true; + if (auto* usage_mask_ptr = GetCalleeUsageMask(op->tuple)) { + auto& used_indices = *usage_mask_ptr; + + CHECK_GE(op->index, 0) << "IndexError: " + << "Indices for TupleGetItem must be non-negative, " + << "but expression " << GetRef(op) << " uses a tuple index of " + << op->index; + size_t index = op->index; + + CHECK_LT(index, used_indices.size()) + << "IndexError: " + << "Indices for TupleGetItem must be less than the size of the tuple, " + << "but expression " << GetRef(op) << " uses a tuple index of " << op->index + << " for a tuple of size " << used_indices.size(); + used_indices[index] = true; + } + } + + void VisitExpr_(const VarNode* op) override { + if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef(op))) { + auto& usage_mask = *usage_mask_ptr; + for (size_t i = 0; i < usage_mask.size(); i++) { + usage_mask[i] = true; + } + } + } + + std::vector* GetCalleeUsageMask(Expr expr) { + if (!expr->struct_info_.as()) { + return nullptr; + } + + expr = UnwrapBindings(expr); + if (auto call = expr.as()) { + if (auto callee = call->op.as()) { + if (auto it = output_usage_mask_.find(callee.value()); it != output_usage_mask_.end()) { + return &it->second; } } } + + return nullptr; } Expr UnwrapBindings(Expr expr) const { diff --git a/tests/python/relax/test_transform_remove_unused_outputs.py b/tests/python/relax/test_transform_remove_unused_outputs.py index c0405ca58d00..365ce1695d0e 100644 --- a/tests/python/relax/test_transform_remove_unused_outputs.py +++ b/tests/python/relax/test_transform_remove_unused_outputs.py @@ -119,5 +119,25 @@ def func() -> R.Tuple([R.Tensor([16, 16], "int32"), R.Tensor([32, 32], "int32")] return (A, C) +class TestReturnTuple(BaseCompare): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16, 16], "int32")): + B = R.add(A, A) + out_tuple = Before.func(B) + return out_tuple + + @R.function(private=True) + def func( + B: R.Tensor([16, 16], "int32") + ) -> R.Tuple(R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32")): + C = R.multiply(B, B) + D = R.add(B, B) + return (C, D) + + Expected = Before + + if __name__ == "__main__": tvm.testing.main() From dcd32ac6368f0d34b5c7823d90aa5a701e3728e8 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 7 Sep 2024 01:01:53 -0400 Subject: [PATCH 120/202] [DOCS] Minor fix typo in developer howto guide (#17343) This PR provides a minor fix of developer howto guide. --- docs/how_to/dev/index.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/how_to/dev/index.rst b/docs/how_to/dev/index.rst index c70832358a41..c815871b4147 100644 --- a/docs/how_to/dev/index.rst +++ b/docs/how_to/dev/index.rst @@ -15,8 +15,8 @@ specific language governing permissions and limitations under the License. -Develope Apache TVM -=================== +Development Guides +================== This section contains a collection of tips about how to work on various areas of the TVM stack. From 521ab47edf1a2b25b6614d64df5d9f6133dfa329 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Sun, 8 Sep 2024 18:40:49 +0800 Subject: [PATCH 121/202] [MSC] Reconstruct tensorrt module (#17344) * reconstruct tensorrt * format fix --- .../contrib/msc/core/frontend/translate.py | 2 +- .../framework/tensorrt/frontend/translate.py | 5 +- .../framework/tensorrt/transform/pattern.py | 31 +- .../framework/tensorrt/transform/transform.py | 13 +- .../msc/core/transform/rewrite_utils.cc | 58 ++ .../msc/core/transform/rewrite_utils.h | 72 ++ src/contrib/msc/core/utils.cc | 19 +- src/contrib/msc/core/utils.h | 4 +- .../msc/framework/tensorrt/tensorrt_opcode.cc | 6 +- .../framework/tensorrt/transform_tensorrt.cc | 668 +++++++++++------- .../test_msc/test_translate_tensorrt.py | 47 +- 11 files changed, 642 insertions(+), 283 deletions(-) create mode 100644 src/contrib/msc/core/transform/rewrite_utils.cc create mode 100644 src/contrib/msc/core/transform/rewrite_utils.h diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py index 63b4424524eb..cea021ade331 100644 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ b/python/tvm/contrib/msc/core/frontend/translate.py @@ -330,7 +330,7 @@ def _is_target_func(func): msc_mod = _partition_mod(mod) func_names = [var.name_hint for var, func in msc_mod.functions.items() if _is_target_func(func)] - if not trans_config.get("allow_incomplete", False): + if trans_config.get("as_complete", True): assert len(func_names) == 1, "More than 1 target func is found: " + str(msc_mod) BYOCChecker().check(func_names, msc_mod[entry]) diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py index 8758fdb63079..4a02b02728de 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py @@ -49,7 +49,10 @@ def transform_for_tensorrt( return tvm.transform.Sequential( [ msc_transform.SetExprName(), - trt_transform.TransformTensorRT(trans_config.get("version")), + trt_transform.TransformTensorRT( + version=trans_config.get("version"), + linear_to_conv=trans_config.get("linear_to_conv", False), + ), relax.transform.FoldConstant(), ] )(mod) diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py index 8eea3f7081a7..17aee690e370 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py @@ -136,12 +136,22 @@ def _check_expr(expr: relax.Expr, dtypes: Tuple[str] = None) -> bool: return True if isinstance(expr, relax.Tuple): return all(_check_expr(field) for field in expr.fields) - if any(i < 0 for i in expr.struct_info.shape.values): - return False - dtypes = dtypes or ("float32", "float16") - if expr.struct_info.dtype not in dtypes: - return False - return True + dtypes = dtypes or ("float32", "float16", "int64", "int32", "bool") + + def _check(sinfo): + if not sinfo.shape or sinfo.dtype not in dtypes: + return False + unknown_dim = 0 + for s in sinfo.shape.values: + if isinstance(s, (tvm.tir.Var, tvm.tir.Any)): + unknown_dim += 1 + elif isinstance(s, tvm.tir.IntImm) and s < 0: + unknown_dim += 1 + return unknown_dim <= 1 + + if isinstance(expr.struct_info, relax.TupleStructInfo): + return all(_check(s) for s in expr.struct_info.fields) + return _check(expr.struct_info) def _basic_check(context: PatternCheckContext) -> bool: @@ -216,8 +226,7 @@ def _reshape_check(context: PatternCheckContext) -> bool: Whether the pattern is correct. """ - dtypes = ("float32", "float16", "int32") - if any(not _check_expr(context.annotated_expr[key], dtypes) for key in ["input_0", "out"]): + if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", "out"]): return False return True @@ -323,16 +332,18 @@ def get_patterns(target) -> List[Pattern]: "nn.avg_pool2d": ["input"], "nn.conv2d": ["input", "constant"], "nn.max_pool2d": ["input"], + "astype": ["input"], "concat": ["input"], "clip": ["input", "input", "input"], "image.resize2d": ["input", "input"], "matmul": ["input", "input"], "permute_dims": ["input"], - "strided_slice": ["input"], + "strided_slice": ["input", "input", "input", "input", "input"], + "topk": ["input"], } activation_ops = ["nn.relu", "nn.softmax", "sigmoid", "tanh"] reduce_ops = ["max", "min", "mean", "sum"] - unary_ops = ["cos", "exp", "negative", "round", "sin", "square", "sqrt", "tan"] + unary_ops = ["cos", "erf", "exp", "negative", "round", "sin", "square", "sqrt", "tan"] elemwise_ops = [ "add", "divide", diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py index d6f15c43dacd..cf4d4b9f33ec 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py @@ -25,18 +25,25 @@ from tvm.contrib.msc.core import utils as msc_utils -def TransformTensorRT(version: List[int] = None) -> tvm.ir.transform.Pass: +def TransformTensorRT( + version: List[int] = None, linear_to_conv: bool = False +) -> tvm.ir.transform.Pass: """Transform the Function to fit TensorRT. Parameters ---------- version: list The tensorrt version. + linear_to_conv: bool + Whether to cast linear to conv2d Returns ------- ret: tvm.ir.transform.Pass """ - version = version or msc_utils.get_version(MSCFramework.TENSORRT) - return relax_api.TransformTensorRT(version) # type: ignore + config = { + "version": version or msc_utils.get_version(MSCFramework.TENSORRT), + "linear_to_conv": linear_to_conv, + } + return relax_api.TransformTensorRT(msc_utils.dump_dict(config)) # type: ignore diff --git a/src/contrib/msc/core/transform/rewrite_utils.cc b/src/contrib/msc/core/transform/rewrite_utils.cc new file mode 100644 index 000000000000..20e4821e6fa7 --- /dev/null +++ b/src/contrib/msc/core/transform/rewrite_utils.cc @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/transform/rewrite_utils.cc + */ +#include "rewrite_utils.h" + +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { + +Var RewriteUtils::ReEmit(BlockBuilder builder, const String& name, const Expr& expr) { + expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name); + return builder->Emit(expr, name); +} + +Var RewriteUtils::MakeCall(BlockBuilder builder, const String& name, Expr op, Array args, + Attrs attrs) { + const auto& call = Call(op, args, attrs); + return ReEmit(builder, name, call); +} + +Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name, double value, + const DataType& dtype, size_t ndim) { + const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value)); + Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name); + const auto& constant = Constant(data, NullOpt, span); + if (ndim == 0) { + return constant; + } + static const Op& reshape_op = Op::Get("relax.reshape"); + Array exp_shape(ndim, Integer(1)); + return MakeCall(builder, name + "_exp", reshape_op, {constant, ShapeExpr(exp_shape)}); +} + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/core/transform/rewrite_utils.h b/src/contrib/msc/core/transform/rewrite_utils.h new file mode 100644 index 000000000000..2693a6ccd2eb --- /dev/null +++ b/src/contrib/msc/core/transform/rewrite_utils.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/transform/rewrite_utils.h + * \brief Common utilities for rewrite. + */ +#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ +#define TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ + +#include +#include + +#include + +#include "../../../../relax/transform/utils.h" +#include "../../../../support/scalars.h" +#include "../utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using Expr = tvm::RelayExpr; +using namespace tvm::relax; + +/*! + * \brief Utils for Layout. + */ +class RewriteUtils { + public: + /*! + * \brief Emit call with span name. + * \return The emitted var. + */ + TVM_DLL static Var ReEmit(BlockBuilder builder, const String& name, const Expr& expr); + + /*! + * \brief Make and emit a call binding with span. + * \return The emitted var. + */ + TVM_DLL static Var MakeCall(BlockBuilder builder, const String& name, Expr op, Array args, + Attrs attrs = Attrs()); + + /*! + * \brief Make and emit a (shaped)constant with span. + * \return The constant/reshape. + */ + TVM_DLL static Expr MakeConstant(BlockBuilder builder, const String& name, double value, + const DataType& dtype, size_t ndim = 0); +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index c6e74d42843d..1e846b0b3a61 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -507,12 +507,25 @@ const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { return name; } -const Array ExprUtils::GetShape(const Expr& expr) { - const auto& shape_opt = Downcast(relax::GetStructInfo(expr))->GetShape(); - ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr; +const Array ExprUtils::GetShape(const relax::TensorStructInfo& sinfo, bool as_int) { + const auto& shape_opt = sinfo->GetShape(); + if (!shape_opt.defined()) { + return Array(); + } + if (as_int) { + Array shape; + for (const auto& s : shape_opt.value()) { + shape.push_back(s->IsInstance() ? s : Integer(-1)); + } + return shape; + } return shape_opt.value(); } +const Array ExprUtils::GetShape(const Expr& expr, bool as_int) { + return GetShape(Downcast(relax::GetStructInfo(expr)), as_int); +} + const DataType ExprUtils::GetDataType(const Expr& expr) { return Downcast(relax::GetStructInfo(expr))->dtype; } diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index d7758cc23d8b..7fb9c87a99f9 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -398,7 +398,9 @@ class ExprUtils { * \brief Get shape of expr. * \return The shape. */ - TVM_DLL static const Array GetShape(const Expr& expr); + TVM_DLL static const Array GetShape(const relax::TensorStructInfo& sinfo, + bool as_int = true); + TVM_DLL static const Array GetShape(const Expr& expr, bool as_int = true); /*! * \brief Get dtype of expr. diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc index a080fdd77862..d90cdc35d17d 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -92,6 +92,8 @@ const String TensorRTOpCode::DType(const DataType& dtype) { dtype_enum = "DataType::kINT8"; } else if (dtype_name == "int32") { dtype_enum = "DataType::kINT32"; + } else if (dtype_name == "int64") { + dtype_enum = "DataType::kINT32"; } else if (dtype_name == "float16") { dtype_enum = "DataType::kHALF"; } else if (dtype_name == "float32") { @@ -267,7 +269,7 @@ class TensorRTAstypeCodeGen : public TensorRTOpCode { void CodeGenBuild() final { stack_.op_call() .op_input_arg() - .func_call("setOutput", NullOpt, DocUtils::ToPtr(IdxNode())) + .func_call("setOutputType", NullOpt, DocUtils::ToPtr(IdxNode())) .call_arg(0) .op_dtype_arg(node()->OutputAt(0)->dtype); } @@ -661,7 +663,7 @@ class TensorRTTopkCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { - const String& symbol = node()->GetTypeAttr("is_asend") ? "MIN" : "MAX"; + const String& symbol = node()->GetTypeAttr("largest") ? "MAX" : "MIN"; stack_.op_call() .op_input_arg() .call_arg("TopKOperation::k" + symbol) diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 3f85309cd847..542e15d06c3c 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -22,83 +22,101 @@ * \brief Pass for transform the function to tensorrt. */ +#include #include #include #include #include "../../../../relax/transform/utils.h" #include "../../../../support/scalars.h" +#include "../../core/transform/rewrite_utils.h" #include "../../core/utils.h" namespace tvm { namespace relax { using namespace tvm::contrib::msc; -const Array GetShape(const Expr& var) { - const auto& shape_opt = Downcast(GetStructInfo(var))->GetShape(); - ICHECK(shape_opt.defined()) << "Shape is not defined for " << var; - return shape_opt.value(); -} - -Var EmitCall(BlockBuilder builder, const Expr& expr, const Span& src_span, const String& suffix) { - const auto& name = SpanUtils::GetAttr(src_span, msc_attr::kName) + "_" + suffix; - expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name); - return builder->Emit(expr, name); -} - -Var MakeCall(BlockBuilder builder, const Span& src_span, const String& suffix, Expr op, - Array args, Attrs attrs = Attrs()) { - const auto& call = Call(op, args, attrs); - return EmitCall(builder, call, src_span, suffix); -} +struct TensorRTTransConfig { + // Whether to cast linear to conv + bool linear_to_conv{true}; + std::vector version{0, 0, 0}; + + void Load(dmlc::JSONReader* reader) { + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "linear_to_conv") { + reader->Read(&linear_to_conv); + } else if (key == "version") { + reader->Read(&version); + } else { + LOG(FATAL) << "Do not support key " << key; + } + } + } +}; -Expr MakeConstant(double value, const DataType& dtype, const String& name) { - const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value)); - const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, name); - return Constant(data, NullOpt, span); +const TensorRTTransConfig ParseConfig(const String& config_str) { + TensorRTTransConfig config; + if (config_str.size() > 0) { + std::istringstream is(config_str); + dmlc::JSONReader reader(&is); + reader.Read(&config); + } + return config; } using FRewriteTensorRT = runtime::TypedPackedFunc& new_calls, const Array& version)>; + const Map& new_calls, const String& config)>; + +const Array BroadcastShape(const Array& src_shape, + const Array& out_shape) { + size_t diff = out_shape.size() - src_shape.size(); + Array leading_shape, tailing_shape; + for (size_t i = 0; i < diff; i++) { + leading_shape.push_back(Integer(1)); + } + for (const auto& s : src_shape) { + tailing_shape.push_back(s); + leading_shape.push_back(s); + } + for (size_t i = 0; i < diff; i++) { + tailing_shape.push_back(Integer(1)); + } + if (ArrayUtils::Broadcastable(tailing_shape, out_shape)) { + return tailing_shape; + } + ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape)) + << "Only support elemwise ops with leading or tailing expand"; + return leading_shape; +} Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& shape_a = GetShape(call->args[0]); - const auto& shape_b = GetShape(call->args[1]); + const auto& shape_a = ExprUtils::GetShape(call->args[0]); + const auto& shape_b = ExprUtils::GetShape(call->args[1]); + const auto& shape_out = ExprUtils::GetShape(var); static const Op& reshape_op = Op::Get("relax.reshape"); if (shape_a.size() > shape_b.size()) { - Array exp_shape(shape_a.size(), Integer(1)); - if (shape_b.size() == 1) { - exp_shape.Set(shape_a.size() - 1, shape_b[0]); - } else if (shape_b.size() == 0) { - LOG_DEBUG << "Expand scalar argument to " << exp_shape; - } else { - LOG_FATAL << "broadcast only support 1 dim and scalar, get " << shape_b; - } - const auto& expand_b = MakeCall(builder, call->span, "expand_b", reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); + const auto& exp_shape = BroadcastShape(shape_b, shape_out); + const auto& expand_b = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_b"), reshape_op, + {call->args[1], ShapeExpr(exp_shape)}); return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span); - } - if (shape_a.size() < shape_b.size()) { - Array exp_shape(shape_b.size(), Integer(1)); - if (shape_a.size() == 1) { - exp_shape.Set(shape_b.size() - 1, shape_a[0]); - } else if (shape_a.size() == 0) { - LOG_DEBUG << "Expand scalar argument to " << exp_shape; - } else { - LOG_FATAL << "broadcast only support 1 dim and scalar, get " << shape_a; - } - const auto& expand_a = MakeCall(builder, call->span, "expand_a", reshape_op, - {call->args[0], ShapeExpr(exp_shape)}); + } else if (shape_a.size() < shape_b.size()) { + const auto& exp_shape = BroadcastShape(shape_a, shape_out); + const auto& expand_a = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_a"), reshape_op, + {call->args[0], ShapeExpr(exp_shape)}); return Call(call->op, {expand_a, call->args[1]}, call->attrs, call->sinfo_args, call->span); } return call; } Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; if (new_calls.count(call->args[0]) && new_calls[call->args[0]]->op == Op::Get("relax.nn.conv1d")) { @@ -110,19 +128,20 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, if (conv2d->op != Op::Get("relax.nn.conv2d")) { return call; } - const auto& input_shape = GetShape(call->args[0]); - const auto& bias_shape = GetShape(call->args[1]); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& bias_shape = ExprUtils::GetShape(call->args[1]); const auto* conv_attrs = conv2d->attrs.as(); if (conv_attrs->data_layout == "NCHW") { // expand bias reshape Array exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), bias_shape[2]}; static const Op& reshape_op = Op::Get("relax.reshape"); - const auto& exp_bias = MakeCall(builder, call->span, "exp_bias", reshape_op, - {call->args[1], ShapeExpr(exp_bias_shape)}); + const auto& exp_bias = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_bias"), reshape_op, + {call->args[1], ShapeExpr(exp_bias_shape)}); // redirect to conv2d static const Op& add_op = Op::Get("relax.add"); - const auto& exp_add = - MakeCall(builder, call->span, "exp_add", add_op, {reshape->args[0], exp_bias}); + const auto& exp_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_add"), + add_op, {reshape->args[0], exp_bias}); // reduce output return Call(reshape_op, {exp_add, ShapeExpr(input_shape)}, Attrs(), call->sinfo_args, call->span); @@ -130,48 +149,50 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, LOG_FATAL << "Unexpected data layout " << conv_attrs->data_layout; } } - return RewriteElemwise(builder, var, call, new_calls, version); + return RewriteElemwise(builder, var, call, new_calls, config); } Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& out_dtype = Downcast(GetStructInfo(var))->dtype; + const auto& out_dtype = ExprUtils::GetDataType(var); const auto* src_attrs = src_call->attrs.as(); - Expr raw_var; - if (src_attrs->keepdims) { - raw_var = EmitCall(builder, call, call->span, "raw"); - } else { - auto new_attrs = make_object(); - new_attrs->axis = src_attrs->axis; - new_attrs->keepdims = true; - raw_var = - MakeCall(builder, call->span, "keepdims", call->op, {call->args[0]}, Attrs(new_attrs)); + ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64)) + << "Unexpected out dtype " << out_dtype; + static const Op& topk_op = Op::Get("relax.topk"); + auto topk_attrs = make_object(); + topk_attrs->k = 1; + if (src_attrs->axis.defined()) { + topk_attrs->axis = src_attrs->axis.value()->value; } - static const Op& astype_op = Op::Get("relax.astype"); - auto cast_to_attrs = make_object(); - cast_to_attrs->dtype = DataType::Int(32); - Expr res = MakeCall(builder, call->span, "cast_to", astype_op, {raw_var}, Attrs(cast_to_attrs)); - // reshape back - if (!src_attrs->keepdims) { - const auto& output_shape = GetShape(var); - static const Op& reshape_op = Op::Get("relax.reshape"); - res = MakeCall(builder, call->span, "reshape", reshape_op, {res, ShapeExpr(output_shape)}); + topk_attrs->largest = call->op == Op::Get("relax.argmax"); + topk_attrs->ret_type = "both"; + topk_attrs->dtype = out_dtype; + // change to topk + const auto& topk = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "topk"), topk_op, + {call->args[0]}, Attrs(topk_attrs)); + const auto& get_name = ExprUtils::GetSpanName(call, ".1"); + const auto& get_item = + TupleGetItem(topk, 1, SpanUtils::CreateWithAttr(msc_attr::kName, get_name)); + if (src_attrs->keepdims) { + return get_item; } - auto cast_from_attrs = make_object(); - cast_from_attrs->dtype = out_dtype; - return Call(astype_op, {res}, Attrs(cast_from_attrs), call->sinfo_args, call->span); + const auto& get_item_var = builder->Emit(get_item, get_name); + static const Op& reshape_op = Op::Get("relax.reshape"); + const auto& output_shape = ExprUtils::GetShape(var); + return Call(reshape_op, {get_item_var, ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, + call->span); } Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); // define dims - const auto& in_q_shape = GetShape(call->args[0]); - const auto& in_v_shape = GetShape(call->args[2]); + const auto& in_q_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_v_shape = ExprUtils::GetShape(call->args[2]); const auto& batch_size = in_q_shape[0]; const auto& seq_len = in_q_shape[1]; const auto& num_head = in_q_shape[2]; @@ -198,50 +219,53 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call auto permute_attrs = make_object(); Array axes{Integer(0), Integer(2), Integer(1), Integer(3)}; permute_attrs->axes = axes; - const auto& q_trans = MakeCall(builder, call->span, "q_trans", permute_dims_op, {call->args[0]}, - Attrs(permute_attrs)); - const auto& k_trans = MakeCall(builder, call->span, "k_trans", permute_dims_op, {call->args[1]}, - Attrs(permute_attrs)); - const auto& v_trans = MakeCall(builder, call->span, "v_trans", permute_dims_op, {call->args[2]}, - Attrs(permute_attrs)); + const auto& q_trans = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"), permute_dims_op, + {call->args[0]}, Attrs(permute_attrs)); + const auto& k_trans = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_trans"), permute_dims_op, + {call->args[1]}, Attrs(permute_attrs)); + const auto& v_trans = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"), permute_dims_op, + {call->args[2]}, Attrs(permute_attrs)); Array q_shape({batch_size * num_head, seq_len, head_dim}); - const auto& q_reshape = - MakeCall(builder, call->span, "q_reshape", reshape_op, {q_trans, ShapeExpr(q_shape)}); + const auto& q_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_reshape"), + reshape_op, {q_trans, ShapeExpr(q_shape)}); Array k_shape({batch_size * num_head, seq_len_kv, head_dim}); - const auto& k_reshape = - MakeCall(builder, call->span, "k_reshape", reshape_op, {k_trans, ShapeExpr(k_shape)}); + const auto& k_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape"), + reshape_op, {k_trans, ShapeExpr(k_shape)}); Array v_shape({batch_size * num_head, seq_len_kv, head_dim_v}); - const auto& v_reshape = - MakeCall(builder, call->span, "v_reshape", reshape_op, {v_trans, ShapeExpr(v_shape)}); + const auto& v_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_reshape"), + reshape_op, {v_trans, ShapeExpr(v_shape)}); auto reduce_permute_attrs = make_object(); Array v_axes{Integer(0), Integer(2), Integer(1)}; reduce_permute_attrs->axes = v_axes; // transpose for batch_matmul - const auto& k_reshape_trans = MakeCall(builder, call->span, "k_reshape_trans", permute_dims_op, - {k_reshape}, Attrs(reduce_permute_attrs)); + const auto& k_reshape_trans = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape_trans"), + permute_dims_op, {k_reshape}, Attrs(reduce_permute_attrs)); // calculate product auto matmul_attrs = make_object(); matmul_attrs->out_dtype = in_dtype; - const auto& qk_prod = MakeCall(builder, call->span, "qk_prod", matmul_op, - {q_reshape, k_reshape_trans}, Attrs(matmul_attrs)); + const auto& qk_prod = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"), matmul_op, + {q_reshape, k_reshape_trans}, Attrs(matmul_attrs)); Expr p_scale; if (src_attrs->scale.defined()) { - const auto& scale = MakeConstant(static_cast(src_attrs->scale.value()->value), in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_scale"); - Array exp_shape(3, Integer(1)); - const auto& exp_scale = - MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, ShapeExpr(exp_shape)}); - p_scale = MakeCall(builder, call->span, "p_scale", multiply_op, {qk_prod, exp_scale}); + double value = static_cast(src_attrs->scale.value()->value); + const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"), + value, in_dtype, 3); + p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), multiply_op, + {qk_prod, scale}); } else { - const auto& scale = - MakeConstant(static_cast(Downcast(head_dim)->value), in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_scale"); - Array exp_shape(3, Integer(1)); - const auto& exp_scale = - MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, ShapeExpr(exp_shape)}); - const auto& sqrt_scale = MakeCall(builder, call->span, "sqrt_scale", sqrt_op, {exp_scale}); - p_scale = MakeCall(builder, call->span, "p_scale", divide_op, {qk_prod, sqrt_scale}); + double value = static_cast(Downcast(head_dim)->value); + const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"), + value, in_dtype, 3); + const auto& sqrt_scale = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "sqrt_scale"), sqrt_op, {scale}); + p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), divide_op, + {qk_prod, sqrt_scale}); } // bias @@ -249,12 +273,12 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call if (call->args.size() == 4) { Array exp_shape{batch_size, num_head, seq_len, seq_len_kv}; Array reduce_shape{batch_size * num_head, seq_len, seq_len_kv}; - const auto& prod_exp = - MakeCall(builder, call->span, "prod_exp", reshape_op, {prod, ShapeExpr(exp_shape)}); - const auto& prod_add = - MakeCall(builder, call->span, "prod_add", add_op, {prod_exp, call->args[3]}); - prod = MakeCall(builder, call->span, "prod_reduce", reshape_op, - {prod_add, ShapeExpr(reduce_shape)}); + const auto& prod_exp = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_exp"), + reshape_op, {prod, ShapeExpr(exp_shape)}); + const auto& prod_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_add"), + add_op, {prod_exp, call->args[3]}); + prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_reduce"), reshape_op, + {prod_add, ShapeExpr(reduce_shape)}); } // causal_mask @@ -262,7 +286,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call if (!src_attrs->causal_mask.defined()) { auto softmax_attrs = make_object(); softmax_attrs->axis = 2; - s_value = MakeCall(builder, call->span, "act", softmax_op, {prod}, Attrs(softmax_attrs)); + s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), softmax_op, + {prod}, Attrs(softmax_attrs)); } else { const auto& causal_mask = src_attrs->causal_mask.value(); PrimValue tril_k; @@ -273,41 +298,47 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call } else { LOG_FATAL << "Unexpected causal_mask " << causal_mask; } - const auto& p_masked = MakeCall(builder, call->span, "p_masked", tril_op, {prod, tril_k}); + const auto& p_masked = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked"), + tril_op, {prod, tril_k}); auto reduce_attrs = make_object(); Array axis{Integer(2)}; reduce_attrs->axis = axis; reduce_attrs->keepdims = true; - const auto& p_max = MakeCall(builder, call->span, "p_max", max_op, {prod}, Attrs(reduce_attrs)); - const auto& p_diff = MakeCall(builder, call->span, "p_diff", subtract_op, {p_masked, p_max}); - const auto& p_exp = MakeCall(builder, call->span, "p_exp", exp_op, {p_diff}); - const auto& p_masked_exp = - MakeCall(builder, call->span, "p_masked_exp", tril_op, {p_exp, tril_k}); + const auto& p_max = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_max"), + max_op, {prod}, Attrs(reduce_attrs)); + const auto& p_diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_diff"), + subtract_op, {p_masked, p_max}); + const auto& p_exp = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_exp"), exp_op, {p_diff}); + const auto& p_masked_exp = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "p_masked_exp"), tril_op, {p_exp, tril_k}); const auto& p_masked_sum = - MakeCall(builder, call->span, "p_masked_sum", sum_op, {p_masked_exp}, Attrs(reduce_attrs)); - s_value = MakeCall(builder, call->span, "act", divide_op, {p_masked_exp, p_masked_sum}); + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked_sum"), sum_op, + {p_masked_exp}, Attrs(reduce_attrs)); + s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), divide_op, + {p_masked_exp, p_masked_sum}); } // final calculation - const auto& o_prod = - MakeCall(builder, call->span, "o_prod", matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs)); + const auto& o_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_prod"), + matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs)); Array o_shape{batch_size, num_head, seq_len, head_dim_v}; return Call(reshape_op, {o_prod, ShapeExpr(o_shape)}, Attrs(), call->sinfo_args, call->span); } Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); // define expand shape Array exp_shape(input_shape.size(), Integer(1)); exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]); // create eps constant - const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); + const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), + src_attrs->epsilon, in_dtype); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -318,36 +349,43 @@ Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call static const Op& subtract_op = Op::Get("relax.subtract"); // scale factor: gamma/sqrt(var + epsilon) - const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {call->args[4], eps}); - const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add}); - const auto& scale_factor = - MakeCall(builder, call->span, "scale_factor", divide_op, {call->args[1], sqrt}); + const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), + add_op, {call->args[4], eps}); + const auto& sqrt = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); + const auto& scale_factor = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "scale_factor"), divide_op, {call->args[1], sqrt}); Expr res = call->args[0]; // scale if (src_attrs->scale) { - const auto& exp_scale = MakeCall(builder, call->span, "exp_scale", reshape_op, - {scale_factor, ShapeExpr(exp_shape)}); - res = MakeCall(builder, call->span, "scale", multiply_op, {res, exp_scale}); + const auto& exp_scale = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_scale"), reshape_op, + {scale_factor, ShapeExpr(exp_shape)}); + res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "scale"), multiply_op, + {res, exp_scale}); } // offset if (src_attrs->center) { // offset factor: beta-mean*scale_factor - const auto& average = - MakeCall(builder, call->span, "average", multiply_op, {call->args[3], scale_factor}); + const auto& average = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "average"), + multiply_op, {call->args[3], scale_factor}); const auto& offset_factor = - MakeCall(builder, call->span, "offset_factor", subtract_op, {call->args[2], average}); - const auto& exp_offset = MakeCall(builder, call->span, "exp_offset", reshape_op, - {offset_factor, ShapeExpr(exp_shape)}); - res = MakeCall(builder, call->span, "offset", add_op, {res, exp_offset}); + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset_factor"), subtract_op, + {call->args[2], average}); + const auto& exp_offset = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_offset"), reshape_op, + {offset_factor, ShapeExpr(exp_shape)}); + res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset"), add_op, + {res, exp_offset}); } return Tuple(Array{res}, call->span); } Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& output_shape = GetShape(var); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(var); Expr concat_input = call->args[0]; static const Op& concat_op = Op::Get("relax.concat"); for (size_t i = 0; i < input_shape.size(); i++) { @@ -357,30 +395,33 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_ca Array concat_inputs(out_dim / in_dim, concat_input); auto concat_attrs = make_object(); concat_attrs->axis = Integer(i); - concat_input = MakeCall(builder, call->span, "concat_" + std::to_string(i), concat_op, - {Tuple(concat_inputs)}, Attrs(concat_attrs)); + concat_input = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "concat_" + std::to_string(i)), concat_op, + {Tuple(concat_inputs)}, Attrs(concat_attrs)); } } return concat_input; } Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto* src_attrs = src_call->attrs.as(); - const auto& input_shape = GetShape(call->args[0]); - const auto& weight_shape = GetShape(call->args[1]); - const auto& output_shape = GetShape(var); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& weight_shape = ExprUtils::GetShape(call->args[1]); + const auto& output_shape = ExprUtils::GetShape(var); if (src_attrs->data_layout == "NCW") { Array new_args; // expand inputs Array exp_input_shape{input_shape[0], input_shape[1], Integer(1), input_shape[2]}; Array exp_weight_shape{weight_shape[0], weight_shape[1], Integer(1), weight_shape[2]}; static const Op& reshape_op = Op::Get("relax.reshape"); - new_args.push_back(MakeCall(builder, call->span, "exp_input", reshape_op, - {call->args[0], ShapeExpr(exp_input_shape)})); - new_args.push_back(MakeCall(builder, call->span, "exp_weight", reshape_op, - {call->args[1], ShapeExpr(exp_weight_shape)})); + new_args.push_back(RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_input"), + reshape_op, + {call->args[0], ShapeExpr(exp_input_shape)})); + new_args.push_back(RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_weight"), + reshape_op, + {call->args[1], ShapeExpr(exp_weight_shape)})); // change to conv2d static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); auto conv_attrs = make_object(); @@ -393,8 +434,8 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, conv_attrs->kernel_layout = "OIHW"; conv_attrs->out_layout = "NCHW"; conv_attrs->out_dtype = src_attrs->out_dtype; - const auto& conv2d = - MakeCall(builder, call->span, "exp", conv2d_op, new_args, Attrs(conv_attrs)); + const auto& conv2d = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp"), + conv2d_op, new_args, Attrs(conv_attrs)); // reduce output return Call(reshape_op, {conv2d, ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, call->span); @@ -404,11 +445,80 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, return call; } +Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call, + const Map& new_calls, const String& config) { + // 0.5 * x * (1 + erf(sqrt(0.5) * x)) + const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; + size_t in_dim = ExprUtils::GetShape(call->args[0]).size(); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); + // create ops + static const Op& add_op = Op::Get("relax.add"); + static const Op& multiply_op = Op::Get("relax.multiply"); + static const Op& erf_op = Op::Get("relax.erf"); + + const auto& factor = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "factor"), + std::sqrt(0.5), in_dtype, in_dim); + const auto& mul = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul"), + multiply_op, {factor, call->args[0]}); + const auto& erf = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "erf"), erf_op, {mul}); + const auto& one = + RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 1, in_dtype, in_dim); + const auto& add = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, {one, erf}); + const auto& mul2 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul2"), + multiply_op, {call->args[0], add}); + const auto& half = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 0.5, + in_dtype, in_dim); + return Call(multiply_op, {half, mul2}, Attrs(), call->sinfo_args, call->span); +} + +Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call& src_call, + const Map& new_calls, const String& config) { + // 0.5 * x * (1 + tanh(sqrt(2/pi) * (0.044715F * pow(x, 3) + x))) + const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; + size_t in_dim = ExprUtils::GetShape(call->args[0]).size(); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); + + // create ops + static const Op& add_op = Op::Get("relax.add"); + static const Op& multiply_op = Op::Get("relax.multiply"); + static const Op& pow_op = Op::Get("relax.power"); + static const Op& tanh_op = Op::Get("relax.tanh"); + + const auto& pow_factor = RewriteUtils::MakeConstant( + builder, ExprUtils::GetSpanName(call, "pow_factor"), 3, in_dtype, in_dim); + const auto& mul_factor = RewriteUtils::MakeConstant( + builder, ExprUtils::GetSpanName(call, "mul_factor"), 0.044715, in_dtype, in_dim); + const auto& pi_factor = RewriteUtils::MakeConstant( + builder, ExprUtils::GetSpanName(call, "pi_factor"), std::sqrt(2 / M_PI), in_dtype, in_dim); + + const auto& pow = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "pow"), pow_op, + {call->args[0], pow_factor}); + const auto& mul = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul"), + multiply_op, {mul_factor, pow}); + const auto& add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, + {mul, call->args[0]}); + const auto& mul2 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul2"), + multiply_op, {pi_factor, add}); + const auto& tanh = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "tanh"), tanh_op, {mul2}); + const auto& one = + RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 1, in_dtype, in_dim); + const auto& add2 = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, {one, tanh}); + const auto& mul3 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul3"), + multiply_op, {call->args[0], add2}); + const auto& half = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 0.5, + in_dtype, in_dim); + return Call(multiply_op, {half, mul3}, Attrs(), call->sinfo_args, call->span); +} + Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); Array group_shape = input_shape; Array exp_shape(input_shape.size(), Integer(1)); @@ -420,8 +530,8 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call exp_shape.Set(axis, Integer(src_attrs->num_groups)); // create eps constant - const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); + const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), + src_attrs->epsilon, in_dtype); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -434,53 +544,63 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call static const Op& subtract_op = Op::Get("relax.subtract"); // reshape input - const auto& reshape_in = MakeCall(builder, call->span, "reshape_in", reshape_op, - {call->args[0], ShapeExpr(group_shape)}); + const auto& reshape_in = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "reshape_in"), reshape_op, + {call->args[0], ShapeExpr(group_shape)}); // mean(input) auto mean_attrs = make_object(); mean_attrs->axis = src_attrs->axes; mean_attrs->keepdims = true; - const auto& mean = - MakeCall(builder, call->span, "mean", mean_op, {reshape_in}, Attrs(mean_attrs)); + const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op, + {reshape_in}, Attrs(mean_attrs)); // variance: mean((input-mean)*(input-mean)) - const auto& diff = MakeCall(builder, call->span, "diff", subtract_op, {reshape_in, mean}); - const auto& square = MakeCall(builder, call->span, "square", square_op, {diff}); - const auto& variance = - MakeCall(builder, call->span, "variance", mean_op, {square}, Attrs(mean_attrs)); + const auto& diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "diff"), + subtract_op, {reshape_in, mean}); + const auto& square = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"), square_op, {diff}); + const auto& variance = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "variance"), + mean_op, {square}, Attrs(mean_attrs)); // sqrt(var + epsilon) Array exp_eps_shape(input_shape.size(), Integer(1)); - const auto& exp_eps = - MakeCall(builder, call->span, "exp_eps", reshape_op, {eps, ShapeExpr(exp_eps_shape)}); - const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {variance, exp_eps}); - const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add}); + const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"), + reshape_op, {eps, ShapeExpr(exp_eps_shape)}); + const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), + add_op, {variance, exp_eps}); + const auto& sqrt = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); // diff/sqrt - Expr res = MakeCall(builder, call->span, "divide", divide_op, {diff, sqrt}); + Expr res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "divide"), divide_op, + {diff, sqrt}); // scale if (src_attrs->scale) { - const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma", reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); - res = MakeCall(builder, call->span, "scale", multiply_op, {res, exp_gamma}); + const auto& exp_gamma = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_gamma"), reshape_op, + {call->args[1], ShapeExpr(exp_shape)}); + res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "scale"), multiply_op, + {res, exp_gamma}); } // offset if (src_attrs->center) { - const auto& exp_beta = MakeCall(builder, call->span, "exp_beta", reshape_op, - {call->args[2], ShapeExpr(exp_shape)}); - res = MakeCall(builder, call->span, "offset", add_op, {res, exp_beta}); + const auto& exp_beta = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_beta"), reshape_op, + {call->args[2], ShapeExpr(exp_shape)}); + res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset"), add_op, + {res, exp_beta}); } // reshape output return Call(reshape_op, {res, ShapeExpr(input_shape)}, Attrs(), call->sinfo_args, call->span); } Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); Array exp_shape(input_shape.size(), Integer(1)); for (const auto& a : src_attrs->axes) { @@ -488,8 +608,8 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call exp_shape.Set(index, input_shape[index]); } // create eps constant - const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); + const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), + src_attrs->epsilon, in_dtype); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -505,30 +625,36 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call auto mean_attrs = make_object(); mean_attrs->axis = src_attrs->axes; mean_attrs->keepdims = true; - const auto& mean = - MakeCall(builder, call->span, "mean", mean_op, {call->args[0]}, Attrs(mean_attrs)); + const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op, + {call->args[0]}, Attrs(mean_attrs)); // variance: mean((input-mean)*(input-mean)) - const auto& diff = MakeCall(builder, call->span, "diff", subtract_op, {call->args[0], mean}); - const auto& square = MakeCall(builder, call->span, "square", square_op, {diff}); - const auto& variance = - MakeCall(builder, call->span, "variance", mean_op, {square}, Attrs(mean_attrs)); + const auto& diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "diff"), + subtract_op, {call->args[0], mean}); + const auto& square = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"), square_op, {diff}); + const auto& variance = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "variance"), + mean_op, {square}, Attrs(mean_attrs)); // sqrt(var + epsilon) Array exp_eps_shape(input_shape.size(), Integer(1)); - const auto& exp_eps = - MakeCall(builder, call->span, "exp_eps", reshape_op, {eps, ShapeExpr(exp_eps_shape)}); - const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {variance, exp_eps}); - const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add}); + const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"), + reshape_op, {eps, ShapeExpr(exp_eps_shape)}); + const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), + add_op, {variance, exp_eps}); + const auto& sqrt = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); // diff/sqrt Call res = Call(divide_op, {diff, sqrt}, Attrs(), call->sinfo_args, call->span); // scale if (src_attrs->scale) { - const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma", reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); - const auto& res_var = EmitCall(builder, res, call->span, "pre_scale"); + const auto& exp_gamma = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_gamma"), reshape_op, + {call->args[1], ShapeExpr(exp_shape)}); + const auto& res_var = + RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call, "pre_scale"), res); if (src_attrs->center) { res = Call(multiply_op, {res_var, exp_gamma}); } else { @@ -537,87 +663,126 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call } // offset if (src_attrs->center) { - const auto& exp_beta = MakeCall(builder, call->span, "exp_beta", reshape_op, - {call->args[2], ShapeExpr(exp_shape)}); - const auto& res_var = EmitCall(builder, res, call->span, "pre_offset"); + const auto& exp_beta = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_beta"), reshape_op, + {call->args[2], ShapeExpr(exp_shape)}); + const auto& res_var = + RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call, "pre_offset"), res); res = Call(add_op, {res_var, exp_beta}, Attrs(), call->sinfo_args, call->span); } return res; } Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { + const auto& trt_config = ParseConfig(config); const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& shape_a = GetShape(call->args[0]); - const auto& shape_b = GetShape(call->args[1]); + const auto& shape_a = ExprUtils::GetShape(call->args[0]); + const auto& shape_b = ExprUtils::GetShape(call->args[1]); static const Op& reshape_op = Op::Get("relax.reshape"); + if (call->args[1]->IsInstance() && shape_b.size() == 2 && + trt_config.linear_to_conv) { + const auto& out_shape = ExprUtils::GetShape(var); + PrimExpr accumulate = ArrayUtils::Accumulate(shape_a, shape_a.size() - 1); + Array exp_shape{accumulate, shape_a[shape_a.size() - 1], Integer(1), Integer(1)}; + const auto& exp_in = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_in"), + reshape_op, {call->args[0], ShapeExpr(exp_shape)}); + // transpose and expand weight to OIHW + static const Op& permute_dims_op = Op::Get("relax.permute_dims"); + auto permute_attrs = make_object(); + Array axes{Integer(1), Integer(0)}; + permute_attrs->axes = axes; + const auto& trans_weight = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "trans_weight"), + permute_dims_op, {call->args[1]}, Attrs(permute_attrs)); + Array weight_shape{shape_b[1], shape_b[0], Integer(1), Integer(1)}; + const auto& exp_weight = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_weight"), reshape_op, + {trans_weight, ShapeExpr(weight_shape)}); + // to conv2d + static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); + auto conv_attrs = make_object(); + conv_attrs->strides = Array{Integer(1), Integer(1)}; + conv_attrs->padding = Array{Integer(0), Integer(0), Integer(0), Integer(0)}; + conv_attrs->dilation = Array{Integer(1), Integer(1)}; + conv_attrs->groups = 1; + conv_attrs->data_layout = "NCHW"; + conv_attrs->kernel_layout = "OIHW"; + conv_attrs->out_layout = "NCHW"; + conv_attrs->out_dtype = ExprUtils::GetDataType(var); + const auto& conv2d = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "conv2d"), + conv2d_op, {exp_in, exp_weight}, Attrs(conv_attrs)); + return Call(reshape_op, {conv2d, ShapeExpr(out_shape)}, Attrs(), call->sinfo_args, call->span); + } if (shape_a.size() > shape_b.size()) { Array exp_shape(shape_a.size(), Integer(1)); - for (size_t i = shape_b.size(); i < shape_a.size(); i++) { - exp_shape.Set(i, shape_b[i - shape_b.size()]); + size_t diff = shape_a.size() - shape_b.size(); + for (size_t i = diff; i < shape_a.size(); i++) { + exp_shape.Set(i, shape_b[i - diff]); } - const auto& expand_b = MakeCall(builder, call->span, "expand_b", reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); + const auto& expand_b = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_b"), reshape_op, + {call->args[1], ShapeExpr(exp_shape)}); return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span); } if (shape_a.size() < shape_b.size()) { Array exp_shape(shape_b.size(), Integer(1)); - for (size_t i = shape_a.size(); i < shape_b.size(); i++) { - exp_shape.Set(i, shape_a[i - shape_a.size()]); + size_t diff = shape_b.size() - shape_a.size(); + for (size_t i = diff; i < shape_b.size(); i++) { + exp_shape.Set(i, shape_a[i - diff]); } - const auto& expand_a = MakeCall(builder, call->span, "expand_a", reshape_op, - {call->args[0], ShapeExpr(exp_shape)}); + const auto& expand_a = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_a"), reshape_op, + {call->args[0], ShapeExpr(exp_shape)}); return Call(call->op, {expand_a, call->args[1]}, call->attrs, call->sinfo_args, call->span); } return call; } Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; - Array exp_shape(input_shape.size(), Integer(1)); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); // create 1 constant - const auto& one = - MakeConstant(1, in_dtype, SpanUtils::GetAttr(call->span, msc_attr::kName) + "_one"); + const auto& one = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), 1, + in_dtype, input_shape.size()); // create ops - static const Op& reshape_op = Op::Get("relax.reshape"); static const Op& divide_op = Op::Get("relax.divide"); static const Op& sqrt_op = Op::Get("relax.sqrt"); // expand and divide - const auto& exp_one = - MakeCall(builder, call->span, "exp_one", reshape_op, {one, ShapeExpr(exp_shape)}); - const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {call->args[0]}); - return Call(divide_op, {exp_one, sqrt}, Attrs(), call->sinfo_args, call->span); + const auto& sqrt = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, + {call->args[0]}); + return Call(divide_op, {one, sqrt}, Attrs(), call->sinfo_args, call->span); } Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; // create ops static const Op& multiply_op = Op::Get("relax.multiply"); static const Op& sigmoid_op = Op::Get("relax.sigmoid"); // silu=input*sigmoid(input) - const auto& sigmoid = MakeCall(builder, call->span, "sigmoid", sigmoid_op, {call->args[0]}); + const auto& sigmoid = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sigmoid"), + sigmoid_op, {call->args[0]}); return Call(multiply_op, {call->args[0], sigmoid}, Attrs(), call->sinfo_args, call->span); } Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& output_shape = GetShape(var); + const auto& output_shape = ExprUtils::GetShape(var); static const Op& reshape_op = Op::Get("relax.reshape"); return Call(reshape_op, {call->args[0], ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, call->span); } Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto* src_attrs = src_call->attrs.as(); size_t axis = CommonUtils::GetIndex(src_attrs->axis, input_shape.size()); std::vector split_begins, split_ends; @@ -646,9 +811,16 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, // create strided_slices Array outputs; for (size_t i = 0; i < split_begins.size(); i++) { - auto slice = strided_slice(call->args[0], Tuple(Array{PrimValue(Integer(axis))}), - Tuple(Array{PrimValue(Integer(split_begins[i]))}), - Tuple(Array{PrimValue(Integer(split_ends[i]))})); + static const Op& strided_slice_op = Op::Get("relax.strided_slice"); + const auto& axes = Tuple(Array{PrimValue(IntImm(DataType::Int(64), axis))}); + const auto& begin = Tuple(Array{PrimValue(IntImm(DataType::Int(64), split_begins[i]))}); + const auto& end = Tuple(Array{PrimValue(IntImm(DataType::Int(64), split_ends[i]))}); + const auto& strides = Tuple(Array{PrimValue(IntImm(DataType::Int(64), 1))}); + auto attrs = make_object(); + attrs->assume_inbound = true; + const auto& slice = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "slice_" + std::to_string(i)), strided_slice_op, + {call->args[0], axes, begin, end, strides}, Attrs(attrs)); outputs.push_back(slice); } return Tuple(outputs, call->span); @@ -664,6 +836,9 @@ TVM_REGISTER_OP("relax.nn.batch_norm") TVM_REGISTER_OP("relax.nn.conv1d").set_attr("FRewriteTensorRT", RewriteConv1d); TVM_REGISTER_OP("relax.nn.group_norm") .set_attr("FRewriteTensorRT", RewriteGroupNorm); +TVM_REGISTER_OP("relax.nn.gelu").set_attr("FRewriteTensorRT", RewriteGelu); +TVM_REGISTER_OP("relax.nn.gelu_tanh") + .set_attr("FRewriteTensorRT", RewriteGeluTanh); TVM_REGISTER_OP("relax.nn.layer_norm") .set_attr("FRewriteTensorRT", RewriteLayerNorm); TVM_REGISTER_OP("relax.nn.silu").set_attr("FRewriteTensorRT", RewriteSilu); @@ -695,9 +870,9 @@ TVM_REGISTER_OP("relax.split").set_attr("FRewriteTensorRT", Re class TensorRTTransformer : public ExprMutator { public: - explicit TensorRTTransformer(IRModule ctx_module, const Array& version) + explicit TensorRTTransformer(IRModule ctx_module, const String& config) : ExprMutator(ctx_module) { - version_ = version; + config_ = config; } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { @@ -707,7 +882,7 @@ class TensorRTTransformer : public ExprMutator { if (rewrite_map.count(op)) { const auto& call = GetRef(call_node); FRewriteTensorRT f = rewrite_map[op]; - const auto& new_call = f(builder_, binding->var, call, new_calls_, version_); + const auto& new_call = f(builder_, binding->var, call, new_calls_, config_); if (new_call != call) { ReEmitBinding(binding, builder_->Normalize(new_call)); new_calls_.Set(binding->var, call); @@ -721,20 +896,19 @@ class TensorRTTransformer : public ExprMutator { private: Map new_calls_; - Array version_; + String config_; }; -Function TransformTensorRT(const Function& func, const IRModule& module, - const Array& version) { - return Downcast(TensorRTTransformer(module, version).VisitExpr(func)); +Function TransformTensorRT(const Function& func, const IRModule& module, const String& config) { + return Downcast(TensorRTTransformer(module, config).VisitExpr(func)); } namespace transform { -Pass TransformTensorRT(const Array& version) { +Pass TransformTensorRT(const String& config) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return relax::TransformTensorRT(f, m, version); + return relax::TransformTensorRT(f, m, config); }; return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {}); } diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py index 74c25ceacfe8..7c8c2830995c 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorrt.py +++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py @@ -87,7 +87,7 @@ def _is_target_func(func): NameChecker().check(func) -def verify_model(torch_model, input_info, allow_incomplete=False): +def verify_model(torch_model, input_info, **trans_config): """Build model and verify results""" graph_model = fx.symbolic_trace(torch_model) @@ -100,9 +100,7 @@ def verify_model(torch_model, input_info, allow_incomplete=False): golden = [golden] golden = [g.detach().cpu().numpy() for g in golden] # partition module for tensorrt - mod, graphs, weights = translate.partition_for_tensorrt( - mod, trans_config={"allow_incomplete": allow_incomplete} - ) + mod, graphs, weights = translate.partition_for_tensorrt(mod, trans_config=trans_config) check_names(mod) output_folder = msc_utils.msc_dir() # tranalte to tensorrt @@ -191,6 +189,8 @@ def forward(self, x, y): input_info = [([1, 3, 10, 10], "float32")] verify_model(Dense1(), input_info) verify_model(Dense2(), input_info) + verify_model(Dense1(), input_info, linear_to_conv=True) + verify_model(Dense2(), input_info, linear_to_conv=True) verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) @@ -368,10 +368,10 @@ def __init__(self): self.embedding = torch.nn.Embedding(10, 3) def forward(self, data): - return self.embedding(data) + return self.embedding(data.to(torch.int64)) - verify_model(Embedding(), [([4], "int64")], allow_incomplete=True) - verify_model(Embedding(), [([4, 5], "int64")], allow_incomplete=True) + verify_model(Embedding(), [([4], "int32")]) + verify_model(Embedding(), [([4, 5], "int32")]) @requires_tensorrt @@ -801,14 +801,14 @@ def test_argmax(): class Argmax1(Module): def forward(self, data): - return torch.argmax(data, dim=-1) + return torch.argmax(data, dim=-1).to(torch.int32) class Argmax2(Module): def forward(self, data): - return torch.argmax(data, dim=-1, keepdim=True) + return torch.argmax(data, dim=-1, keepdim=True).to(torch.int32) - verify_model(Argmax1(), [([256, 256], "float32")], allow_incomplete=True) - verify_model(Argmax2(), [([256, 256], "float32")], allow_incomplete=True) + verify_model(Argmax1(), [([256, 256], "float32")]) + verify_model(Argmax2(), [([256, 256], "float32")]) @requires_tensorrt @@ -817,14 +817,14 @@ def test_argmin(): class Argmin1(Module): def forward(self, data): - return torch.argmin(data, dim=-1) + return torch.argmin(data, dim=-1).to(torch.int32) class Argmin2(Module): def forward(self, data): - return torch.argmin(data, dim=-1, keepdim=True) + return torch.argmin(data, dim=-1, keepdim=True).to(torch.int32) - verify_model(Argmin1(), [([256, 256], "float32")], allow_incomplete=True) - verify_model(Argmin2(), [([256, 256], "float32")], allow_incomplete=True) + verify_model(Argmin1(), [([256, 256], "float32")]) + verify_model(Argmin2(), [([256, 256], "float32")]) @requires_tensorrt @@ -876,5 +876,22 @@ def forward(self, x, y): verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) +@requires_tensorrt +def test_gelu(): + """test tensorrt translator for gelu""" + + class Gelu1(Module): + def forward(self, data): + return torch.nn.functional.gelu(data) + + class Gelu2(Module): + def forward(self, data): + return torch.nn.functional.gelu(data, approximate="tanh") + + input_info = [([1, 3, 10, 10], "float32")] + verify_model(Gelu1(), input_info) + verify_model(Gelu2(), input_info) + + if __name__ == "__main__": tvm.testing.main() From 995524a84276869c14a231a84f66d56fca3afe73 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 8 Sep 2024 05:41:48 -0500 Subject: [PATCH 122/202] [Relax] Refactor RealizeVDevice to remove in-place mutation (#17213) * [Relax] Refactor RealizeVDevice to remove in-place mutation Prior to this commit, the `relax.transform.RealizeVDevice` pass performed in-place update on expressions appearing in its input `IRModule`, overwriting their struct info. In-place mutation of TVM's IR types is only legal when the scope has sole ownership of the IR object, such as through the `CopyOnWrite` functionality, and is not allowed when the object is shared. As a result, applying `RealizeVDevice` would cause unexpected updates in unrelated expressions. Most noticeably, the `IRModule` used as input to `RealizeVDevice` would have its variable erroneously updated. This commit refactors the `RealizeVDevice` transform to remove all in-place mutation. The same propagation rules are followed, with known `VDevice` annotations propagated forward from the output of `R.hint_on_device`, and propagated backwards from the input of `R.hint_on_device` if no such annotation already exists. Closes https://github.com/apache/tvm/issues/17205. * lint fixes --- src/relax/transform/realize_vdevice.cc | 492 +++++++++++------- .../relax/test_transform_realize_vdevice.py | 80 +++ 2 files changed, 389 insertions(+), 183 deletions(-) diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index ec02efa996e6..0df86515dbcc 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -29,259 +29,385 @@ namespace tvm { namespace relax { -void UpdateTensorStructInfo(Expr expr, StructInfo struct_info) { - if (auto* tensor_sinfo = expr->struct_info_.as()) { - auto* new_tensor_sinfo = struct_info.as(); - if (new_tensor_sinfo != nullptr && new_tensor_sinfo->vdevice.defined() && - !tensor_sinfo->vdevice.defined()) { - expr->struct_info_ = struct_info; - expr->checked_type_ = GetStaticType(struct_info); - } +namespace { + +class VDeviceLookup { + public: + explicit VDeviceLookup(IRModule mod) { + auto opt_global_info = mod->global_infos.Get("vdevice"); + if (!opt_global_info) return; + + auto downcast_vdevice = [](GlobalInfo info) -> VDevice { + if (auto vdevice = info.as()) { + return vdevice.value(); + } else { + LOG(FATAL) << "TypeError: " + << "Each item in an IRModule's \"vdevice\" annotation must be a VDevice, " + << "but instead found item of type " << info->GetTypeKey(); + } + }; + + opt_vdevices_ = opt_global_info.value().Map(downcast_vdevice); } -} -void AddVDeviceToStuctInfo(Expr expr, VDevice vdevice) { - auto* tinfo = GetStructInfoAs(expr); - if (tinfo != nullptr) { - if (tinfo->shape.defined()) { - UpdateTensorStructInfo( - expr, TensorStructInfo(tinfo->shape.value(), tinfo->dtype, vdevice, tinfo->span)); - } else { - UpdateTensorStructInfo(expr, - TensorStructInfo(tinfo->dtype, tinfo->ndim, vdevice, tinfo->span)); + VDevice operator()(Attrs hint_on_device_attrs) { + auto attrs = hint_on_device_attrs.as(); + ICHECK(attrs); + int32_t device_type = attrs->dev_type; + int32_t device_id = attrs->dev_id; + + CHECK(opt_vdevices_.defined()) + << "ValueError: The target VDevice in the GlobalInfos was not found."; + + auto vdevices = opt_vdevices_.value(); + CHECK_GE(device_id, 0) << "ValueError: " + << "The device id in R.hint_on_device must not be negative"; + + for (auto vdevice : vdevices) { + int dev_type = vdevice->target->GetTargetDeviceType(); + if (dev_type == device_type && vdevice->vdevice_id == device_id) { + return vdevice; + } } + LOG(FATAL) << "ValueError: " + << "Expected to find device with type " << device_id << " and id " << device_id + << ", but no such device was found in the IRModule's \"vdevice\" annotation"; } -} -class VDeviceRealizer : public ExprMutator { + private: + Optional> opt_vdevices_ = NullOpt; +}; + +class DeviceHintCollector : ExprVisitor { public: - explicit VDeviceRealizer(const IRModule& mod) : ExprMutator(mod), mod_(std::move(mod)) {} + static std::tuple, Map> Collect(IRModule mod) { + DeviceHintCollector visitor{VDeviceLookup(mod)}; - IRModule Run() { - for (const auto& [gv, func] : mod_->functions) { - if (func->IsInstance()) { - auto updated_func = Downcast(this->VisitExpr(func)); - builder_->UpdateFunction(gv, Downcast(updated_func)); + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + visitor(func.value()); } } - return builder_->GetContextIRModule(); + + return {visitor.known_vdevice_, visitor.hint_on_device_inputs_}; } private: - using ExprMutator::VisitExpr_; + explicit DeviceHintCollector(VDeviceLookup vdevice_lookup) : vdevice_lookup_(vdevice_lookup) {} + + void VisitExpr_(const FunctionNode* func) override { + ExprVisitor::VisitExpr_(func); + + std::function check_ret_sinfo = [this, &check_ret_sinfo]( + Expr expr, StructInfo sinfo) { + // If the function is annotated as returning a tensor on a + // specific device, then that annotation may be propagated into + // the returned variable. + if (auto tensor_info = sinfo.as(); + tensor_info && tensor_info->vdevice.defined()) { + if (auto opt_var = expr.as()) { + auto var = opt_var.value(); + if (!known_vdevice_.count(var)) { + known_vdevice_.Set(var, tensor_info->vdevice.value()); + } + } + } - void AddToVDeviceMap(Expr expr, VDevice vdevice) { - ICHECK((vdevice_map_.count(expr) == 0) || (vdevice_map_[expr] == vdevice)) - << "Conflicted vdevice found."; - vdevice_map_.Set(expr, vdevice); + // If the function is annotated as returning a tuple of tensors, + // where some elements of the tuple are tensors that exist on a + // specific device, then those annotations may be propagated + // into the corresponding tensor annotations. + if (auto tuple_info = sinfo.as()) { + // The returned tuple is not necessarily an in-line tuple. In + // order to find the variables that are bound to the + // individual tuple elements, we may need to unwrap the + // variable bindings in order to find the tuple itself. This + // unwrapping is not required for the tensor case, as it would + // already be handled when propagating VDevice across variable + // definitions. + while (auto bound_value = LookupBinding(expr)) { + expr = bound_value.value(); + } + + // Even after unwrapping variable bindings, the resulting + // expression is not required to be a tuple literal. For + // example, the function may return one of its arguments as an + // output, or may return the result of a `relax::Call` that + // produces a tuple of outputs. + if (auto tuple = expr.as()) { + CHECK_EQ(tuple_info->fields.size(), tuple->fields.size()) + << "ValueError: " + << "Function returns a tuple with " << tuple->fields.size() << " elements, " + << "but is annotated as returning a tuple with " << tuple_info->fields.size() + << " elements"; + for (size_t i = 0; i < tuple->fields.size(); i++) { + check_ret_sinfo(tuple->fields[i], tuple_info->fields[i]); + } + } + } + }; + + check_ret_sinfo(func->body->body, func->ret_struct_info); } - Expr VisitExpr(const Expr& expr) { - auto visited_expr = ExprMutator::VisitExpr(expr); - if (vdevice_map_.count(visited_expr)) { - AddVDeviceToStuctInfo(visited_expr, vdevice_map_[visited_expr]); + void VisitVarDef(const Var& var) override { + if (auto tinfo = var->struct_info_.as(); + tinfo && tinfo->vdevice.defined()) { + known_vdevice_.Set(var, tinfo->vdevice.value()); } - return visited_expr; + ExprVisitor::VisitVarDef(var); } - Expr VisitExpr_(const FunctionNode* op) final { - Function func = GetRef(op); - auto* finfo = GetStructInfoAs(func); - if (finfo != nullptr) { - StructInfo ret = finfo->ret; - auto* tinfo = finfo->ret.as(); - if (tinfo != nullptr && tinfo->vdevice.defined()) { - AddToVDeviceMap(op->body, tinfo->vdevice.value()); - } - } - Function visited_func = Downcast(this->VisitExprPostOrder_(op)); - return visited_func; + void VisitBinding(const Binding& binding) override { + ExprVisitor::VisitBinding(binding); + binding_lookup_.Set(binding->var, GetBoundValue(binding)); } - Expr VisitExpr_(const SeqExprNode* op) final { - SeqExpr seq_expr = GetRef(op); - if (vdevice_map_.count(seq_expr)) { - AddToVDeviceMap(seq_expr->body, vdevice_map_[seq_expr]); + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) override { + ExprVisitor::VisitBinding_(binding, call); + if (call->op == hint_on_device_op_) { + auto vdevice = vdevice_lookup_(call->attrs); + known_vdevice_.Set(binding->var, vdevice); + + ICHECK_EQ(call->args.size(), 1); + if (auto arg_var = call->args[0].as()) { + hint_on_device_inputs_.Set(arg_var.value(), vdevice); + } } - SeqExpr visited_seqexpr = Downcast(this->VisitExprPostOrder_(op)); - return visited_seqexpr; } - BindingBlock VisitBindingBlock_(const BindingBlockNode* block) { - builder_->BeginBindingBlock(); - for (size_t i = block->bindings.size(); i > 0; --i) { - this->VisitBinding(block->bindings[i - 1]); - } - for (size_t i = bindings_.size(); i > 0; --i) { - builder_->EmitNormalized(bindings_[i - 1]); + Optional LookupBinding(const Expr& expr) const { + if (auto var = expr.as()) { + if (auto bound = binding_lookup_.Get(var.value())) { + return bound.value(); + } } - bindings_.clear(); - return builder_->EndBlock(); + return NullOpt; } - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { - builder_->BeginDataflowBlock(); - for (size_t i = block->bindings.size(); i > 0; --i) { - this->VisitBinding(block->bindings[i - 1]); - } - for (size_t i = bindings_.size(); i > 0; --i) { - builder_->EmitNormalized(bindings_[i - 1]); + // A lookup to identify the VDevice from the IRModule attributes, + // given the device type and device id from the R.hint_on_device + // attributes. + VDeviceLookup vdevice_lookup_; + + // A lookup of variable bindings, used to unwrap the variable + // bindings in functions that return a tuple. + Map binding_lookup_; + + // A map from Var to the VDevice they are known to occur on. This + // only contains variables whose location is explicitly known + // (e.g. output of `R.hint_on_device`, variables with explicit + // `VDevice` in their struct info), and does not include variables + // whose location is (e.g. input of `R.hint_on_device`). + Map known_vdevice_; + + // A map from Var to the VDevice they are expected to occur on. If + // a variable appears in both `known_vdevice_` and + // `hint_on_device_inputs_`, then `known_vdevice_` takes priority. + // + // For example, `B = R.hint_on_device(A, tvm.cuda(0))` implies that + // `B` must be located on "cuda:0". However, `A` may already have a + // `VDevice` annotation, or may be the output of `R.to_device`. + // Therefore, we only determine that `A` is located on "cuda:0" if + // no other annotation has already provided a known location for + // `A`. + Map hint_on_device_inputs_; + + // The `R.hint_on_device` operator. + const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); +}; + +// Utility to determine which Var instances must be located on the +// same VDevice. +class VDeviceSetCollector : ExprVisitor { + public: + static Map> Collect(IRModule mod) { + VDeviceSetCollector visitor; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + visitor(func.value()); + } } - bindings_.clear(); - return builder_->EndBlock(); + return visitor.var_to_co_located_vars_; } - void VisitBinding_(const VarBindingNode* binding) { - if (vdevice_map_.count(binding->var)) { - AddToVDeviceMap(binding->value, vdevice_map_[binding->var]); - AddVDeviceToStuctInfo(binding->var, vdevice_map_[binding->var]); - } - auto* tinfo = GetStructInfoAs(binding->var); - if (tinfo != nullptr && tinfo->vdevice.defined()) { - AddToVDeviceMap(binding->value, tinfo->vdevice.value()); - } - UpdateTensorStructInfo(binding->value, GetStructInfo(binding->var)); - Expr new_value = this->VisitExpr(binding->value); - if (!binding->var->struct_info_.defined()) { - UpdateTensorStructInfo(binding->var, GetStructInfo(new_value)); - } + private: + void VisitBinding(const Binding& binding) override { + auto cached = current_binding_; + current_binding_ = binding->var; + ExprVisitor::VisitBinding(binding); + current_binding_ = cached; + } - if (new_value.same_as(binding->value)) { - bindings_.push_back(GetRef(binding)); - } else { - bindings_.push_back(VarBinding(binding->var, new_value)); + void VisitExpr_(const CallNode* call) override { + if (call->op != to_vdevice_op_ && call->op != hint_on_device_op_) { + ExprVisitor::VisitExpr_(call); } } - Expr VisitExpr_(const CallNode* call) final { - // Record the vdevice information of each arguments of call - if (auto* sinfo = call->struct_info_.as()) { - if (sinfo->vdevice.defined() && call->op != to_vdevice_op_) { - Array call_args; - for (Expr arg : call->args) { - AddToVDeviceMap(arg, sinfo->vdevice.value()); - } - } + void VisitExpr_(const VarNode* op) override { + if (current_binding_) { + auto var = GetRef(op); + var_to_co_located_vars_[current_binding_.value()].push_back(var); + var_to_co_located_vars_[var].push_back(current_binding_.value()); } - return Downcast(ExprMutator::VisitExpr_(call)); } - /*! \brief The context IRModule. */ - IRModule mod_; - /*! \brief The bindings in reverse ordering. */ - Array bindings_; - /*! \brief The virtual device map. */ - Map vdevice_map_; + Optional current_binding_ = NullOpt; + + // Lookup from relax variable to the set of relax variables which + // must be located on the same device. For example, a trivial + // binding `B = A` implies that both `B` and `A` are on the same + // device. Similarly, `C = R.add(A,B)` implies that `A`, `B`, and + // `C` are all on the same device. + // + // In general, variables that are used as part of the same + // `relax::Call` operation must be located on the same device, with + // the exception of `R.hint_on_device` and `R.to_vdevice`, which may + // introduce a transfer across devices. + std::unordered_map> var_to_co_located_vars_; + const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); }; -class HintOnDeviceRemover : public ExprMutator { - public: - explicit HintOnDeviceRemover(const IRModule& mod) : ExprMutator(mod), mod_(std::move(mod)) {} +Map InferVDevice(IRModule mod) { + auto [explicit_annotations, hint_on_device_args] = DeviceHintCollector::Collect(mod); + + auto co_located_var_lookup = VDeviceSetCollector::Collect(mod); + + Map known_vdevice; + std::vector to_visit; + + // A helper function to propagate all `known_vdevice` entries based + // on the connections in `co_located_var_lookup`. + auto propagate = [&]() { + while (to_visit.size()) { + Var visiting = to_visit.back(); + to_visit.pop_back(); - IRModule Run() { - for (const auto& [gv, func] : mod_->functions) { - if (func->IsInstance()) { - auto updated_func = Downcast(this->VisitExpr(func)); - builder_->UpdateFunction(gv, Downcast(updated_func)); + if (auto upstream_vars = co_located_var_lookup.Get(visiting)) { + auto vdevice = known_vdevice.at(visiting); + for (Var upstream_var : upstream_vars.value()) { + if (!known_vdevice.count(upstream_var)) { + known_vdevice.Set(upstream_var, vdevice); + to_visit.push_back(upstream_var); + } + } } } - return builder_->GetContextIRModule(); + }; + + // First round, mark variables whose vdevice is explicitly known + // (e.g. the output of R.hint_on_device), and propagate. + for (const auto& [var, vdevice] : explicit_annotations) { + to_visit.push_back(var); + known_vdevice.Set(var, vdevice); + } + propagate(); + + // Second round, mark variables whose vdevice is hinted at (e.g. the + // input of R.hint_on_device), and propagate. + for (const auto& [var, vdevice] : hint_on_device_args) { + if (!known_vdevice.count(var)) { + to_visit.push_back(var); + known_vdevice.Set(var, vdevice); + } } + propagate(); - private: - using ExprMutator::VisitExpr_; + return known_vdevice; +} - void AddToVDeviceMap(Expr expr, VDevice vdevice) { - ICHECK((vdevice_map_.count(expr) == 0) || (vdevice_map_[expr] == vdevice)) - << "Conflicted vdevice found."; - vdevice_map_.Set(expr, vdevice); - } +// Update the module to include the inferred VDevice annotations. +class VDeviceStructInfoUpdater : ExprMutator { + public: + static IRModule Apply(IRModule mod, Map vdevice_map) { + VDeviceStructInfoUpdater mutator(VDeviceLookup(mod), vdevice_map); - VDevice LookupVDevice(int32_t device_type, int32_t device_id) { - Array vdevices = mod_->global_infos["vdevice"]; - if (vdevices.empty() || device_id < 0 || static_cast(device_id) >= vdevices.size()) { - LOG(FATAL) << "ValueError: The target VDevice in the GlobalInfos was not found."; - } - for (auto vdev : vdevices) { - auto vdevice = Downcast(vdev); - int dev_type = vdevice->target->GetTargetDeviceType(); - if (dev_type == device_type && vdevice->vdevice_id == device_id) { - return vdevice; + IRModule updates; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + auto updated = Downcast(mutator(func.value())); + if (!updated.same_as(base_func)) { + updates->Add(gvar, updated); + } } } - LOG(WARNING) << "The specified device was not found in the global_infos"; - return VDevice(); - } - Expr VisitExpr(const Expr& expr) { - auto visited_expr = ExprMutator::VisitExpr(expr); - if (vdevice_map_.count(visited_expr)) { - AddVDeviceToStuctInfo(visited_expr, vdevice_map_[visited_expr]); + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); } - return visited_expr; - } - void VisitBinding_(const VarBindingNode* binding) { - Expr new_value = this->VisitExpr(binding->value); - UpdateTensorStructInfo(binding->var, GetStructInfo(new_value)); - if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); - } else { - builder_->EmitNormalized(VarBinding(binding->var, new_value)); - } + return mod; } - Expr VisitExpr_(const CallNode* call) final { - // Replace hint_on_device with to_vdevice - if (call->op == hint_on_device_op_) { - // Find out the vdevice from global_infos - Expr data = call->args[0]; - auto attrs = call->attrs.as(); - int32_t device_type = attrs->dev_type; - int32_t device_id = attrs->dev_id; - VDevice dst_vdev = LookupVDevice(device_type, device_id); - // Insert to_vdevice if input are on different device - auto* tinfo = GetStructInfoAs(data); - if (tinfo != nullptr) { - if (!tinfo->vdevice.defined()) { - // Remove hint_on_device - AddVDeviceToStuctInfo(data, dst_vdev); - AddToVDeviceMap(data, dst_vdev); - return data; - } else if (tinfo->vdevice.value() != dst_vdev) { - // Call to_vdevice - ObjectPtr attrs = make_object(); - attrs->dst_vdevice = dst_vdev; - auto new_call = Call(to_vdevice_op_, {data}, Attrs(attrs), {}); - AddToVDeviceMap(new_call, dst_vdev); - return new_call; + private: + VDeviceStructInfoUpdater(VDeviceLookup vdevice_lookup, Map vdevice_map) + : vdevice_lookup_(vdevice_lookup), vdevice_map_(vdevice_map) {} + + Var VisitVarDef(const Var& old_var) override { + auto var = ExprMutator::VisitVarDef(old_var); + if (auto tinfo = var->struct_info_.as()) { + if (auto opt = vdevice_map_.Get(old_var)) { + auto vdevice = opt.value(); + TensorStructInfo new_sinfo = [&]() { + if (tinfo->shape.defined()) { + return TensorStructInfo(tinfo->shape.value(), tinfo->dtype, vdevice, tinfo->span); + } else { + return TensorStructInfo(tinfo->dtype, tinfo->ndim, vdevice, tinfo->span); + } + }(); + + if (var->IsInstance()) { + var = DataflowVar(var->vid, new_sinfo, var->span); + } else { + var = Var(var->vid, new_sinfo, var->span); } } } - auto visited_call = ExprMutator::VisitExpr_(call); - visited_call->struct_info_ = NullOpt; - return builder_->Normalize(visited_call); + return var; } - /*! \brief The context IRModule. */ - IRModule mod_; - /*! \brief The virtual device map. */ - Map vdevice_map_; + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* op) override { + auto call = Downcast(ExprMutator::VisitExpr_(op)); + + if (call->op != hint_on_device_op_) { + return call; + } + + ICHECK_EQ(call->args.size(), 1); + auto arg = call->args[0]; + auto input_vdevice = Downcast(arg->struct_info_)->vdevice; + auto output_vdevice = vdevice_lookup_(call->attrs); + + if (input_vdevice.defined() && input_vdevice.value() == output_vdevice) { + return arg; + } else { + ObjectPtr attrs = make_object(); + attrs->dst_vdevice = output_vdevice; + return Call(to_vdevice_op_, {arg}, Attrs(attrs), {}); + } + } + VDeviceLookup vdevice_lookup_; + Map vdevice_map_; const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); }; +} // namespace namespace transform { Pass RealizeVDevice() { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { - IRModule new_mod = HintOnDeviceRemover(mod).Run(); - return VDeviceRealizer(new_mod).Run(); + auto known_vdevices = InferVDevice(mod); + return VDeviceStructInfoUpdater::Apply(mod, known_vdevices); }; return CreateModulePass(/*pass_function=*/pass_func, /*opt_level=*/0, diff --git a/tests/python/relax/test_transform_realize_vdevice.py b/tests/python/relax/test_transform_realize_vdevice.py index f8d99eb3b59f..4c530d5e4931 100644 --- a/tests/python/relax/test_transform_realize_vdevice.py +++ b/tests/python/relax/test_transform_realize_vdevice.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Test eliminate common subexpr pass""" + import tvm import tvm.testing from tvm.ir import VDevice @@ -202,6 +203,56 @@ def foo( verify(Input, Expect) +def test_tuple_func_ret(): + @I.ir_module + class Input: + I.module_attrs({"attr": 10}) + I.module_global_infos( + { + "vdevice": [ + I.vdevice("cuda"), + ] + } + ) + + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((2, 3), "float32"), + z: R.Tensor((2, 3), "float32"), + ) -> R.Tuple([R.Tensor((2, 3), "float32", "cuda"), R.Tensor((2, 3), "float32", "cuda")]): + with R.dataflow(): + lv0 = R.add(x, y) + gv = R.multiply(lv0, z) + R.output(gv) + return (gv, gv) + + @I.ir_module + class Expect: + I.module_attrs({"attr": 10}) + I.module_global_infos( + { + "vdevice": [ + I.vdevice("cuda"), + ] + } + ) + + @R.function + def foo( + x: R.Tensor((2, 3), "float32", "cuda"), + y: R.Tensor((2, 3), "float32", "cuda"), + z: R.Tensor((2, 3), "float32", "cuda"), + ) -> R.Tuple([R.Tensor((2, 3), "float32", "cuda"), R.Tensor((2, 3), "float32", "cuda")]): + with R.dataflow(): + lv0: R.Tensor((2, 3), "float32", "cuda") = R.add(x, y) + gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv0, z) + R.output(gv) + return (gv, gv) + + verify(Input, Expect) + + def test_multi_device(): @I.ir_module class Input: @@ -326,5 +377,34 @@ def foo( verify(Input, Expect) +def test_input_module_is_unmodified(): + def make_module(): + @I.ir_module + class Module: + I.module_global_infos({"vdevice": [I.vdevice("llvm")]}) + + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((2, 3), "float32"), + z: R.Tensor((2, 3), "float32"), + ) -> R.Tensor((2, 3), "float32"): + x1 = x + y1 = y + x2 = x1 + y2 = y1 + s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) + m = R.multiply(s, z) + return m + + return Module + + original = make_module() + expected = make_module() + + RealizeVDevice()(original) + tvm.ir.assert_structural_equal(original, expected) + + if __name__ == "__main__": tvm.testing.main() From e468426bfd43fadb555ef0e561b9047a5d89852e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 8 Sep 2024 06:42:06 -0400 Subject: [PATCH 123/202] [Fix][Relax] Add the missing tree-attn func arg for KV cache creation (#17345) This PR fixes the TIRPagedKVCache construction issue, which is caused by missing the tree-attention with paged KV cache kernel. --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 7b14c67a2e57..ae0537f0d9af 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -375,6 +375,7 @@ def __init__( # pylint: disable=too-many-locals bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), rope_ext_factors, # fmt: on # pylint: enable=line-too-long From 35fdf8b16c3cad396dc2d21efe2bc0fc871a2285 Mon Sep 17 00:00:00 2001 From: Krishna Bindumadhavan <31140965+f2013519@users.noreply.github.com> Date: Mon, 9 Sep 2024 00:33:12 +0530 Subject: [PATCH 124/202] [relay][qnn]: Fix qnn.avg_pool2d layout inference (#17339) --- src/relay/qnn/op/avg_pool2d.cc | 8 +- .../relay/test_pass_convert_op_layout.py | 79 +++++++++++++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/src/relay/qnn/op/avg_pool2d.cc b/src/relay/qnn/op/avg_pool2d.cc index b2dc08b85686..e1a28169ccda 100644 --- a/src/relay/qnn/op/avg_pool2d.cc +++ b/src/relay/qnn/op/avg_pool2d.cc @@ -132,9 +132,11 @@ InferCorrectLayoutOutput QnnAvgPoolInferCorrectLayout(const Attrs& attrs, auto avgpool_new_layouts = PoolInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); - // Scales and zero points are scalars, use the "undef" layout for them. - Array input_layouts = {avgpool_new_layouts->input_layouts[0], Layout::Undef(), - Layout::Undef(), Layout::Undef(), Layout::Undef()}; + // Scales and zero points are scalars, the layouts of these tensors can be treated as channel + // layout. + Layout channel_layout = Layout("C"); + Array input_layouts = {avgpool_new_layouts->input_layouts[0], channel_layout, + channel_layout, channel_layout, channel_layout}; Array output_layouts = avgpool_new_layouts->output_layouts; return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs); } diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 49afe492a121..5450f1aa6906 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1542,6 +1542,85 @@ def expected(): tvm.ir.assert_structural_equal(a, b) +def test_qnn_conv_avgpool_2d_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8") + weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8") + y = relay.qnn.op.conv2d( + x, + weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.cast(y, "int8") + y = relay.qnn.op.avg_pool2d( + y, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + layout="NHWC", + out_layout="NHWC", + pool_size=(3, 3), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + y = relay.Function([x, weight], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8") + weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8") + x = relay.layout_transform(x, "NHWC", "NCHW") + weight = relay.layout_transform(weight, "HWIO", "OIHW") + y = relay.qnn.op.conv2d( + x, + weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.cast(y, "int8") + y = relay.qnn.op.avg_pool2d( + y, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + layout="NCHW", + out_layout="NCHW", + pool_size=(3, 3), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass( + a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"], "qnn.avg_pool2d": ["NCHW"]}) + ) + b = run_opt_pass(expected(), transform.InferType()) + + tvm.ir.assert_structural_equal(a, b) + + def test_conv_roi_align_convert_layout(): def before(): x = relay.var("x", shape=(1, 64, 56, 56)) From f02d295e0b38f48efebedcdb62bd82ffa17ef15e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 9 Sep 2024 17:55:50 -0700 Subject: [PATCH 125/202] [CI] Upgrade github upload-artifact action (#17355) --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 759acd1fa506..db2d870da9bd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -175,7 +175,7 @@ jobs: export PATH="${ANDROID_NDK_LATEST_HOME}:$PATH" gradle clean build - name: Upload android_rpc APK - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: android_rpc-debug.apk path: ./apps/android_rpc/app/build/outputs/apk/debug/app-debug.apk @@ -186,7 +186,7 @@ jobs: export PATH="${ANDROID_NDK_LATEST_HOME}:$PATH" gradle clean build - name: Upload android_deploy APK - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: android_deploy-debug.apk path: ./apps/android_deploy/app/build/outputs/apk/debug/app-debug.apk From d7e0af2d88f75e2ab21c6dbde43813a033c0fb35 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Tue, 10 Sep 2024 10:52:35 +0300 Subject: [PATCH 126/202] [LLVM][RUNTIME] Fix RISC-V CodeModel propagation to ORCJIT runtime executor (#17347) --- src/target/llvm/llvm_instance.h | 10 ++++++++++ src/target/llvm/llvm_module.cc | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index fd63140a0b37..add2af6002c6 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -215,6 +215,16 @@ class LLVMTargetInfo { * \return `llvm::TargetOptions` object for this target */ const llvm::TargetOptions& GetTargetOptions() const { return target_options_; } + /*! + * \brief Get the LLVM target reloc model + * \return `llvm::Reloc::Model` object for this target + */ + const llvm::Reloc::Model& GetTargetRelocModel() const { return reloc_model_; } + /*! + * \brief Get the LLVM target code model + * \return `llvm::CodeModel::Model` object for this target + */ + const llvm::CodeModel::Model& GetTargetCodeModel() const { return code_model_; } /*! * \brief Get fast math flags * \return `llvm::FastMathFlags` for this target diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index baa68feedfa2..34bbb6a0c6a9 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -482,6 +482,14 @@ void LLVMModuleNode::InitORCJIT() { tm_builder.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive); #endif + // Default is no explicit JIT code & reloc model + // Propagate instance code & reloc for RISCV case. + auto arch = tm_builder.getTargetTriple().getArch(); + if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) { + tm_builder.setRelocationModel(llvm_target->GetTargetRelocModel()); + tm_builder.setCodeModel(llvm_target->GetTargetCodeModel()); + } + // create the taget machine std::unique_ptr tm = llvm::cantFail(tm_builder.createTargetMachine()); if (!IsCompatibleWithHost(tm.get())) { From ec42883b1efd5016f32b0da8fc6cbbf72a1ce7f4 Mon Sep 17 00:00:00 2001 From: Viranchee Lotia Date: Tue, 10 Sep 2024 12:45:27 -0400 Subject: [PATCH 127/202] [Docs] TVM pip Installation fix (#17352) * TVM pip Installation fix After successfully building tvm on Apple Silicon, I wasn't able to get `pip install` working. It did not find `libtvm.dylib`. Specifying TVM_LIBRARY_PATH seems to fix the issue * Fix lint error + fix naming convention --- docs/install/from_source.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index a963d06ab559..8e2d94db5f9a 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -145,8 +145,8 @@ Leaving the build environment ``tvm-build-venv``, there are two ways to install conda activate your-own-env conda install python # make sure python is installed - cd /path-to-tvm/python - pip install -e . + export TVM_LIBRARY_PATH=/path-to-tvm/build + pip install -e /path-to-tvm/python Step 4. Validate Installation ----------------------------- From cc533b925452bcaaed9a1ca09da8bcb7e9e30622 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 10 Sep 2024 16:31:17 -0700 Subject: [PATCH 128/202] [Relax] Fix inline source module cause path too long error (#17354) When the source is provided as inline string literal, creating `Path` object causes path too long error. --- python/tvm/relax/frontend/nn/extern.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index 332d07cbc3c5..198ef0f23c46 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -228,7 +228,10 @@ def _detect_source_code(source_code) -> str: path = Path(source_code) except: # pylint: disable=bare-except return source_code - if not path.is_file(): + try: + if not path.is_file(): + return source_code + except: # pylint: disable=bare-except return source_code with path.open("r", encoding="utf-8") as file: return file.read() From f52143e6c822b04791961bcdfbf965f5eb1674d2 Mon Sep 17 00:00:00 2001 From: Honglin Zhu Date: Wed, 11 Sep 2024 11:41:40 +0800 Subject: [PATCH 129/202] [Relax][Frontend][Onnx] fix params name bug in onnx frontend (#17350) * fix params name bug * add test_multi_ops_with_same_params and test_params_names_start_with_onnx --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 4 +- tests/python/relax/test_frontend_onnx.py | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index c3116f9988ce..462d1cf92c01 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -91,7 +91,7 @@ def get_constant( # Convert if possible if isinstance(var, relax.Var) and var.name_hint in params: # When converting a parameter to a constant, update references to it as well. - _, value = params.pop(var.name_hint) + _, value = params[var.name_hint] const_value = relax.const(value) graph_nodes[var.name_hint] = const_value return const_value @@ -2152,7 +2152,7 @@ def _parse_graph_initializers(self, graph: onnx.onnx_ml_pb2.GraphProto): init_var = self._new_var(var_name, shape=array.shape, dtype=array.dtype) self._nodes[init_tensor.name] = init_var # We need to keep track of both the real value and variable for this variable. - self._params[init_tensor.name] = (init_var, array) + self._params[var_name] = (init_var, array) # Otherwise we can use the weight as a constant. else: self._nodes[init_tensor.name] = relax.const(array) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 3ea987973578..8f4e9881f497 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1909,5 +1909,48 @@ def test_multi_inputs_with_same_symbolic_shape(): check_correctness(model) +def test_multi_ops_with_same_params(): + reshape_node_1 = helper.make_node("Reshape", ["a", "x"], ["b"]) + reshape_node_2 = helper.make_node("Reshape", ["b", "x"], ["c"]) + + a_shape = [16] + output_shape = [1, 16] + + graph = helper.make_graph( + [reshape_node_1, reshape_node_2], + "test_multi_ops_with_same_params", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape), + ], + initializer=[ + helper.make_tensor("x", TensorProto.INT64, [2], output_shape), + ], + outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="test_multi_ops_with_same_params") + check_correctness(model) + + +def test_params_names_start_with_onnx(): + reshape_node = helper.make_node("Reshape", ["a", "onnx::x"], ["b"]) + + a_shape = [16] + output_shape = [1, 16] + + graph = helper.make_graph( + [reshape_node], + "test_params_names_start_with_onnx", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape), + ], + initializer=[ + helper.make_tensor("onnx::x", TensorProto.INT64, [2], output_shape), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="test_params_names_start_with_onnx") + check_correctness(model) + + if __name__ == "__main__": tvm.testing.main() From 72b75fe5b2f34765892b6ae3ba8709bad318b7bd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 11 Sep 2024 08:34:17 -0500 Subject: [PATCH 130/202] [Relax] Validate StructInfo of variable bindings (#17332) * [Relax] Validate StructInfo of variable bindings In Relax, both the variable and the expression in a `VarBinding` may contain `StructInfo` annotations. Prior to this commit, these `StructInfo` annotations could be inconsistent, assigning an expression to a variable of incompatible type. This commit updates the Relax well-formed checker to verify that the `StructInfo` of Relax variables accurately describes their contents. * Fix unit tests * [Relax][Bugfix] LCA of PrimStructInfo must check known values The `StructInfoLCA` determines the lowest common ancestor between two `StructInfo` annotations. This is primarily used in Relax to determine the appropriate `StructInfo` annotation for a `relax::If` node, given the `StructInfo` of each branch. Prior to this commit, when determining the LCA of two `PrimStructInfo` annotations, the `StructInfoLCA` function only inspected the datatype of `PrimStructInfo` annotations, and did not check for known values. For example, the LCA of `R.Prim(value=T.int64(128))` and `R.Prim(value=T.int64(64))` is `R.Prim("int64")`, but was incorrectly determined as `R.Prim(value=T.int64(128))` by the `StructInfoLCA` function. This commit updates `StructInfoLCA` to inspect the known values of a `PrimStructInfo`, as well as the datatype. --- src/relax/analysis/struct_info_analysis.cc | 23 ++++- src/relax/analysis/well_formed.cc | 12 +++ src/relax/transform/normalize.cc | 6 +- .../test_analysis_struct_info_analysis.py | 94 ++++++++++++++++++- .../python/relax/test_analysis_well_formed.py | 87 +++++++++++++++++ 5 files changed, 216 insertions(+), 6 deletions(-) diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index a7e5404c20ce..6fe8f36020bf 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -982,10 +982,25 @@ class StructInfoLCAFinder StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) return ObjectStructInfo(lhs->span); - if (lhs->dtype == rhs->dtype) return GetRef(lhs); - // PrimType will be treated as their boxed(object) values - // as a result we can unify to object. - return ObjectStructInfo(lhs->span); + if (lhs->dtype != rhs->dtype) { + // PrimType will be treated as their boxed(object) values + // as a result we can unify to object. + return ObjectStructInfo(lhs->span); + } + if (!lhs->value.defined() || !rhs->value.defined() || + !analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) { + // The two values are known to contain the same dtype, but may + // contain different values. + if (!lhs->value.defined()) { + // If the mismatch was due to extra information in the RHS, + // prefer to avoid constructing a new object. + return GetRef(lhs); + } else { + return PrimStructInfo(lhs->dtype, lhs->span); + } + } + + return GetRef(lhs); } StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 235059ece2aa..7688c4a64291 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -429,6 +429,18 @@ class WellFormedChecker : public relax::ExprVisitor, } this->VisitVarDef(binding->var); + + if (check_struct_info_ && binding->var->struct_info_.defined() && + binding->value->struct_info_.defined()) { + auto expr_sinfo = GetStructInfo(binding->value); + auto var_sinfo = GetStructInfo(binding->var); + if (!IsBaseOf(var_sinfo, expr_sinfo)) { + Malformed(Diagnostic::Error(binding->var) + << "Expression of type " << expr_sinfo + << " cannot be assigned to a variable of type " << var_sinfo); + } + } + if (is_lambda) { recur_vars_.erase(binding->var); } diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 89080ebc3eb1..5493b44f822b 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -65,7 +65,11 @@ class NormalizeMutator : public ExprMutatorBase { Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { builder_->BeginBindingBlock(); - builder_->BeginScope(params); + if (params.defined()) { + builder_->BeginScope(params); + } else { + builder_->BeginInnerScope(); + } Expr ret = this->VisitExpr(expr); BindingBlock prologue = builder_->EndBlock(); if (!prologue->bindings.empty()) { diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index 83b1ddd4fc9e..b2931549e92b 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -24,7 +24,7 @@ from tvm import TVMError from tvm import relax as rx from tvm import tir, ir -from tvm.script import relax as R +from tvm.script import relax as R, tir as T def test_get_static_type_basic(): @@ -620,6 +620,98 @@ def fn_info_erased(): _check_lca(fopaque2(), fn_info_shape(1), fopaque2()) +def _generate_prim_test_cases(): + dtypes = [ + "bool", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", + "float16", + "float32", + "float64", + ] + + for dtype in dtypes: + # LCA of a PrimStructInfo with itself yields itself + yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype)) + + # The LCA of two values, each statically known to be the same + # value, is known to have that value. + yield ( + R.Prim(value=tir.const(0, dtype)), + R.Prim(value=tir.const(0, dtype)), + R.Prim(value=tir.const(0, dtype)), + ) + + # The LCA of two values, each of which is statically known to + # have a different value, no longer knows the contained value. + yield ( + R.Prim(value=tir.const(0, dtype)), + R.Prim(value=tir.const(1, dtype)), + R.Prim(dtype=dtype), + ) + + # LCA of a known variable with itself yields itself + var_N = tir.Var("N", dtype) + yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N)) + + # LCA of a known variable with a known static value is no + # longer known to have a specific value. + yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)), R.Prim(dtype=dtype)) + yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype)) + + var_M = tir.Var("M", dtype) + yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype)) + + for dtype_a in dtypes: + for dtype_b in dtypes: + if dtype_a != dtype_b: + # Unlike R.Tensor, R.Prim does not currently support a + # value with an unknown datatype. If the dtype + # differs between the two annotations, the next wider + # category is R.Object. + yield (R.Prim(dtype_a), R.Prim(dtype_b), R.Object) + + # Because the dtypes are different, even `R.Prim` containing + # the same value in different representations (e.g. + # `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`. + yield ( + R.Prim(value=tir.const(0, dtype_a)), + R.Prim(value=tir.const(0, dtype_b)), + R.Object, + ) + + # And the same is true for known variable values + var_N = tir.Var("N", dtype_a) + var_M = tir.Var("M", dtype_b) + yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object) + + +@pytest.mark.parametrize("test_case", list(_generate_prim_test_cases())) +def test_prim_struct_info_lca(test_case): + def _normalize_sinfo(sinfo): + if isinstance(sinfo, tvm.relax.StructInfo): + return sinfo + elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy): + return sinfo.as_struct_info() + elif callable(sinfo): + return sinfo() + else: + raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo") + + lhs, rhs, expected = map(_normalize_sinfo, test_case) + + lca = rx.analysis.struct_info_lca(lhs, rhs) + assert tvm.ir.structural_equal( + lca, expected + ), f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found {lca}" + + def _generate_tir_var_test_cases(): n, m = tir.Var("n", "int64"), tir.Var("m", "int64") shape0 = rx.ShapeStructInfo([1, n, 3]) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index c0b962c3f3a0..3db3efee1afc 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -1208,5 +1208,92 @@ def add_one( assert rx.analysis.well_formed(Module) +def test_var_binding_must_have_compatible_struct_info(): + """Variables must accurately describe their contents + + To be well-formed, the inferred struct info must not conflict with + the StructInfo annotations. + + """ + + # The function is equivalent to the TVMScript below. However, + # TVMScript applies additional checks that would catch this error + # while parsing. In order to validate the well-formed checker + # itself, this test directly constructs the function withoutusing + # TVMScript, skipping the TVMScript-specific checks. + # + # @R.function + # def main( + # A: R.Tensor(shape=[128, 32], dtype="float32"), + # ): + # B: R.Tensor(shape=[128, 32], dtype="int32") = A + # return B + + param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32")) + var = tvm.relax.Var("B", R.Tensor(shape=[128, 32], dtype="int32")) + binding = tvm.relax.VarBinding(var, param) + body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var) + tvm.relax.expr._update_struct_info(body, var.struct_info) + main = tvm.relax.Function([param], body) + + assert not rx.analysis.well_formed(main) + + +def test_var_binding_may_have_less_constrained_struct_info(): + """StructInfo of variable may be less specific than expression + + The StructInfo annotation of a variable is not required to be an + exact match to the expression's StructInfo, and may provide less + specific information than the inference would provide. + + """ + + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + ): + B: R.Object = R.add(A, A) + return B + + assert isinstance( + Module["main"].body.blocks[0].bindings[0].var.struct_info, tvm.relax.ObjectStructInfo + ), "Validity of this test requires a variable with R.Object struct info" + + assert rx.analysis.well_formed(Module) + + +def test_var_binding_with_incomplete_struct_info_must_be_consistent(): + """StructInfo of variable must be accurate + + Even though StructInfo annotation may be less specific, the + information that they do contain must be correct. + + """ + + # The function is equivalent to the TVMScript below. However, + # TVMScript applies additional checks that would catch this error + # while parsing. In order to validate the well-formed checker + # itself, this test directly constructs the function withoutusing + # TVMScript, skipping the TVMScript-specific checks. + # + # @R.function + # def main( + # A: R.Tensor(shape=[128, 32], dtype="float32"), + # ): + # B: R.Tensor(ndim=3) = A + # return B + + param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32")) + var = tvm.relax.Var("B", R.Tensor(ndim=3, dtype="int32")) + binding = tvm.relax.VarBinding(var, param) + body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var) + tvm.relax.expr._update_struct_info(body, var.struct_info) + main = tvm.relax.Function([param], body) + + assert not rx.analysis.well_formed(main) + + if __name__ == "__main__": tvm.testing.main() From 2c4afbb5eace6c52f30d35a5c70465ca63c27a0f Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Wed, 11 Sep 2024 09:55:35 -0400 Subject: [PATCH 131/202] =?UTF-8?q?[Relax][KV=20Cache]=20Refactor=20`=5Fat?= =?UTF-8?q?tention=5Fsequence=5Fprefill`=20function=20to=20=E2=80=A6=20(#1?= =?UTF-8?q?7362)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR removes batch_size from the function signature, instead of mapping it within the function body. --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index ae0537f0d9af..9b16fc2fbfee 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -1237,7 +1237,7 @@ def merge_state_inplace( def _attention_sequence_prefill( - batch_size, h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 + h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 ): # pylint: disable=line-too-long LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv @@ -1264,6 +1264,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches var_output: T.handle, # [total_len, h_q, d] var_lse: T.handle # [total_len, h_q] ): + batch_size = T.int32(is_size_var=True) qo_len = T.int32(is_size_var=True) kv_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype) From 38e726aab191d5c16a7d98b2191a5f97f7fef410 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 04:18:07 +0900 Subject: [PATCH 132/202] [Relax][PyTorch] Cleanup unary op converters (#17356) * classify into 9 types of ops * introduce `_unary_op()` * cleanup `_clamp()` * cleanup `_gelu()` * cleanup `_hardsigmoid()` and `_hardswish()` * cleanup `_leakyrelu()` * cleanup `_log_softmax()` * cleanup `_round()` * cleanup `_softmax()` * cleanup `_tril_triu()` * replace `fx.node.Node` with `fx.Node` --- .../tvm/relax/frontend/torch/fx_translator.py | 566 +++++++++--------- 1 file changed, 288 insertions(+), 278 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index aed38d7c49ea..8d66343254c1 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -35,7 +35,7 @@ def __init__(self) -> None: import torch # type: ignore from torch import fx - self.env: Dict[fx.node.Node, relax.Expr] = {} + self.env: Dict[fx.Node, relax.Expr] = {} self.params: Dict[torch.Tensor, relax.Expr] = {} self.named_modules: Dict[str, torch.Module] = None self.block_builder: relax.BlockBuilder = None @@ -108,7 +108,7 @@ def retrieve_args(self, node): def _retrieve_args(self, node): from torch import fx - if isinstance(node, fx.node.Node): + if isinstance(node, fx.Node): return self.env[node] elif isinstance(node, tuple): return tuple(self._retrieve_args(x) for x in node) @@ -136,33 +136,113 @@ def _call_binary_op(self, op, lhs, rhs): lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs) return self.block_builder.emit(op(lhs, rhs)) - ########## Arithmetic ########## + ########## Unary Ops ########## - def _exp(self, node: fx.node.Node) -> relax.Var: - return self.block_builder.emit(relax.op.exp(self.env[node.args[0]])) + def _unary_op(self, op: Callable) -> Callable: + from torch import fx - def _sigmoid(self, node: fx.node.Node) -> relax.Var: - return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]])) + def convert(node: fx.Node) -> relax.Var: + return self.block_builder.emit(op(self.env[node.args[0]])) - def _sqrt(self, node: fx.node.Node) -> relax.Expr: - arg = self.env[node.args[0]] - if isinstance(arg, (int, float)): - arg = relax.const(arg, "float32") - return self.block_builder.emit(relax.op.sqrt(arg)) + return convert - def _rsqrt(self, node: fx.node.Node) -> relax.Expr: - arg = self.env[node.args[0]] - if isinstance(arg, (int, float)): - arg = relax.const(arg, "float32") - return self.block_builder.emit(relax.op.rsqrt(arg)) + def _clamp(self, node: fx.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = args[1] if len(args) > 1 else node.kwargs["min"] + a_max = args[2] if len(args) > 2 else node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + + def _gelu(self, node: fx.Node) -> relax.Expr: + approximate = node.kwargs.get("approximate", "none") + if approximate == "none": + return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) + elif approximate == "tanh": + return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) + else: + raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) + + def _hardsigmoid(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) + + def _hardswish(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + x2 = relax.op.divide(x1, relax.const(6, dtype)) + return self.block_builder.emit(relax.op.multiply(x, x2)) + + def _leakyrelu(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) + return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + + def _leakyrelu_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + alpha = module.negative_slope + return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + + def _log_softmax(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + + def _log_softmax_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + dim = module.dim + assert dim is not None + return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _round(self, node: fx.node.Node) -> relax.Expr: - if "decimals" in node.kwargs and node.kwargs["decimals"] != 0: + def _round(self, node: fx.Node) -> relax.Expr: + if node.kwargs.get("decimals", 0) != 0: raise ValueError("specifying decimals for round is not supported yet") arg = self.env[node.args[0]] return self.block_builder.emit(relax.op.round(arg)) - def _add(self, node: fx.node.Node) -> relax.Expr: + def _softmax(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _softmax_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + dim = module.dim + assert dim is not None + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0) + assert isinstance(k, int) + return self.block_builder.emit(op(x, k)) + + return convert + + ########## Arithmetic ########## + + def _add(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.add, lhs, rhs) @@ -176,103 +256,54 @@ def _add(self, node: fx.node.Node) -> relax.Expr: ) return lhs + rhs - def _max(self, node: fx.node.Node) -> relax.Expr: + def _max(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.maximum, lhs, rhs) - def _floordiv(self, node: fx.node.Node) -> relax.Expr: + def _floordiv(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.floor_divide, lhs, rhs) return lhs // rhs - def _mul(self, node: fx.node.Node) -> relax.Expr: + def _mul(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.multiply, lhs, rhs) return lhs * rhs - def _pow(self, node: fx.node.Node) -> relax.Expr: + def _pow(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.power, lhs, rhs) return lhs**rhs - def _neg(self, node: fx.node.Node) -> relax.Expr: - x = self.env[node.args[0]] - return self.block_builder.emit(relax.op.negative(x)) - - def _sub(self, node: fx.node.Node) -> relax.Expr: + def _sub(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.subtract, lhs, rhs) return lhs - rhs - def _truediv(self, node: fx.node.Node) -> relax.Expr: + def _truediv(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.divide, lhs, rhs) return lhs / rhs - def _clamp(self, node: fx.node.Node) -> relax.Expr: - args = self.retrieve_args(node) - a_min = node.kwargs["min"] - a_max = node.kwargs["max"] - if not isinstance(a_min, (int, float)): - raise ValueError( - f"TVM only supports constant min value for torch.clamp/clip, " - f"but got {a_min} with type {type(a_min)}" - ) - if not isinstance(a_max, (int, float)): - raise ValueError( - f"TVM only supports constant max value for torch.clamp/clip, " - f"but got {a_max} with type {type(a_max)}" - ) - return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) - - def _gelu(self, node: fx.node.Node) -> relax.Expr: - if "approximate" not in node.kwargs: - approximate = "none" - else: - approximate = node.kwargs["approximate"] - if approximate == "none": - return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) - elif approximate == "tanh": - return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) - else: - raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) - - def _hardsigmoid(self, node: fx.node.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) - - def _hardswish(self, node: fx.node.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - x2 = relax.op.divide(x1, relax.const(6, dtype)) - return self.block_builder.emit(relax.op.multiply(x, x2)) - ########## Compare ########## - def _lt(self, node: fx.node.Node) -> relax.Expr: + def _lt(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) return self._call_binary_op(relax.op.less, lhs, rhs) - def _eq(self, node: fx.node.Node) -> relax.Expr: + def _eq(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) return self._call_binary_op(relax.op.equal, lhs, rhs) ########## Creation ########## - def _arange(self, node: fx.node.Node) -> relax.Var: + def _arange(self, node: fx.Node) -> relax.Var: import torch start_end_step = [None, None, None] @@ -311,15 +342,15 @@ def _arange(self, node: fx.node.Node) -> relax.Var: else: dtype = "int64" start_end_step = [ - self.env[x] if isinstance(x, torch.fx.node.Node) else x for x in start_end_step + self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step ] return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) - def _empty(self, node: fx.node.Node) -> relax.Var: + def _empty(self, node: fx.Node) -> relax.Var: dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) return self.block_builder.emit(relax.op.zeros(node.args, dtype)) - def _inplace_fill(self, node: fx.node.Node) -> relax.Var: + def _inplace_fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] dtype = x.struct_info.dtype @@ -328,7 +359,7 @@ def _inplace_fill(self, node: fx.node.Node) -> relax.Var: self.env[node.args[0]] = filled return filled - def _tensor(self, node: fx.node.Node) -> relax.Var: + def _tensor(self, node: fx.Node) -> relax.Var: dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None if isinstance(node.args[0], float): return relax.const(node.args[0], dtype if dtype is not None else "float32") @@ -336,21 +367,10 @@ def _tensor(self, node: fx.node.Node) -> relax.Var: return relax.const(node.args[0], dtype if dtype is not None else "int64") raise ValueError("torch.tensor with value not a float or int is not accepted") - def _tril_triu(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - k = node.args[1] if len(node.args) > 1 else 0 - assert isinstance(k, int) - return self.block_builder.emit(op(x, k)) - - return convert - def _inplace_tril_triu(self, op: Callable) -> Callable: from torch import fx - def convert(node: fx.node.Node) -> relax.Var: + def convert(node: fx.Node) -> relax.Var: x = self.env[node.args[0]] k = node.args[1] if len(node.args) > 1 else 0 assert isinstance(k, int) @@ -361,7 +381,7 @@ def convert(node: fx.node.Node) -> relax.Var: return convert - def _new_ones(self, node: fx.node.Node) -> relax.Var: + def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] size = args[1:] @@ -376,7 +396,7 @@ def _new_ones(self, node: fx.node.Node) -> relax.Var: ) ) - def _ones(self, node: fx.node.Node) -> relax.Var: + def _ones(self, node: fx.Node) -> relax.Var: import torch args = self.retrieve_args(node) @@ -397,7 +417,7 @@ def _ones(self, node: fx.node.Node) -> relax.Var: ) ) - def _full(self, node: fx.node.Node) -> relax.Var: + def _full(self, node: fx.Node) -> relax.Var: import torch args = self.retrieve_args(node) @@ -421,14 +441,14 @@ def _full(self, node: fx.node.Node) -> relax.Var: ########## Statistical ########## - def _sum(self, node: fx.node.Node) -> relax.Var: + def _sum(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False if len(args) == 1: return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) return self.block_builder.emit(relax.op.sum(args[0], args[1])) - def _mean(self, node: fx.node.Node) -> relax.Var: + def _mean(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False if len(args) == 1: @@ -437,18 +457,18 @@ def _mean(self, node: fx.node.Node) -> relax.Var: ########## DataType ########## - def _float(self, node: fx.node.Node) -> relax.Var: + def _float(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) - def _half(self, node: fx.node.Node) -> relax.Var: + def _half(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) - def _type(self, node: fx.node.Node) -> relax.Var: + def _type(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) return self.block_builder.emit(relax.op.astype(x, dtype)) - def _to(self, node: fx.node.Node) -> relax.Var: + def _to(self, node: fx.Node) -> relax.Var: import torch x = self.env[node.args[0]] @@ -466,7 +486,7 @@ def _to(self, node: fx.node.Node) -> relax.Var: def _matmul_impl(self, a: relax.Expr, b: relax.Expr): return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _matmul(self, node: fx.node.Node) -> relax.Var: + def _matmul(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) res = self._matmul_impl( args[0], @@ -474,7 +494,7 @@ def _matmul(self, node: fx.node.Node) -> relax.Var: ) return res - def _addmm(self, node: fx.node.Node) -> relax.Var: + def _addmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] y = self.env[node.args[1]] z = self.env[node.args[2]] @@ -496,7 +516,7 @@ def _addmm(self, node: fx.node.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) return res - def _baddbmm(self, node: fx.node.Node) -> relax.Var: + def _baddbmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] a = self.env[node.args[1]] b = self.env[node.args[2]] @@ -518,7 +538,7 @@ def _baddbmm(self, node: fx.node.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res - def _einsum(self, node: fx.node.Node) -> relax.Var: + def _einsum(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -526,7 +546,7 @@ def _einsum(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0])) return self.block_builder.emit(relax.op.einsum(args[1:], args[0])) - def _unbind(self, node: fx.node.Node) -> relax.Var: + def _unbind(self, node: fx.Node) -> relax.Var: if len(node.args) == 2: assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int" dim = node.args[1] @@ -544,12 +564,12 @@ def _unbind(self, node: fx.node.Node) -> relax.Var: ########## Manipulation ########## - def _cat(self, node: fx.node.Node) -> relax.Var: + def _cat(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) - def _expand(self, node: fx.node.Node) -> relax.Var: + def _expand(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) broadcast_shape, in_shape = [], self.shape_of(args[0]) for idx, i in enumerate(args[1:]): @@ -559,7 +579,7 @@ def _expand(self, node: fx.node.Node) -> relax.Var: broadcast_shape.append(i) return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) - def _flatten(self, node: fx.node.Node) -> relax.Var: + def _flatten(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: module = self.named_modules[node.target] @@ -579,7 +599,7 @@ def _flatten(self, node: fx.node.Node) -> relax.Var: ) return self.block_builder.emit(relax.op.reshape(x, new_shape)) - def _permute(self, node: fx.node.Node) -> relax.Var: + def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -587,7 +607,7 @@ def _permute(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) - def _reshape(self, node: fx.node.Node) -> relax.Var: + def _reshape(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -595,7 +615,7 @@ def _reshape(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) - def _split(self, node: fx.node.Node) -> relax.Var: + def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] split_size = node.args[1] if "dim" in node.kwargs: @@ -611,7 +631,7 @@ def _split(self, node: fx.node.Node) -> relax.Var: n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size return self.block_builder.emit(relax.op.split(x, n_section, dim)) - def _chunk(self, node: fx.node.Node) -> relax.Var: + def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] chunks = node.args[1] @@ -623,13 +643,13 @@ def _chunk(self, node: fx.node.Node) -> relax.Var: dim = 0 return self.block_builder.emit(relax.op.split(x, chunks, dim)) - def _transpose(self, node: fx.node.Node) -> relax.Var: + def _transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) full_idx = list(range(len(self.shape_of(args[0])))) full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - def _squeeze(self, node: fx.node.Node) -> relax.Var: + def _squeeze(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if "dim" in node.kwargs: @@ -640,7 +660,7 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var: dim = None return self.block_builder.emit(relax.op.squeeze(x, dim)) - def _repeat(self, node: fx.node.Node) -> relax.Var: + def _repeat(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -648,7 +668,7 @@ def _repeat(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - def _tile(self, node: fx.node.Node) -> relax.Var: + def _tile(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -656,7 +676,7 @@ def _tile(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - def _cumsum(self, node: fx.node.Node) -> relax.Var: + def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if "dim" in node.kwargs: @@ -674,13 +694,13 @@ def _cumsum(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) - def _index_select(self, node: fx.node.Node) -> relax.Var: + def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] index = self.env[node.args[2]] return self.block_builder.emit(relax.op.take(x, index, dim)) - def _masked_fill(self, node: fx.node.Node) -> relax.Var: + def _masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] value = node.args[2] @@ -688,7 +708,7 @@ def _masked_fill(self, node: fx.node.Node) -> relax.Var: values = self.block_builder.emit(relax.op.full_like(x, rx_value)) return self.block_builder.emit(relax.op.where(mask, values, x)) - def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var: + def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] value = node.args[2] @@ -703,7 +723,7 @@ def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var: def _argmax_argmin(self, op: Callable) -> Callable: from torch import fx - def convert(node: fx.node.Node): + def convert(node: fx.Node): x = self.env[node.args[0]] dim = None keepdims = False @@ -726,14 +746,14 @@ def convert(node: fx.node.Node): ########## Neural Network ########## - def _linear(self, node: fx.node.Node) -> relax.Var: + def _linear(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] bias = None if module.bias is None else self.params[module.bias] return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _linear_functional(self, node: fx.node.Node) -> relax.Var: + def _linear_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -770,7 +790,7 @@ def _conv1d_impl( bias = relax.op.reshape(bias, (1, -1, 1)) return self.block_builder.emit(relax.op.add(conv1d, bias)) - def _conv1d(self, node: fx.node.Node) -> relax.Var: + def _conv1d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -788,7 +808,7 @@ def _conv1d(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: + def _conv1d_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -838,7 +858,7 @@ def _conv1d_transpose_impl( bias = relax.op.reshape(bias, (1, -1, 1)) return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: + def _conv1d_transpose(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -856,7 +876,7 @@ def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var: + def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -905,7 +925,7 @@ def _conv2d_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1)) return self.block_builder.emit(relax.op.add(conv2d, bias)) - def _conv2d(self, node: fx.node.Node) -> relax.Var: + def _conv2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -923,7 +943,7 @@ def _conv2d(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: + def _conv2d_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -973,7 +993,7 @@ def _conv2d_transpose_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1)) return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: + def _conv2d_transpose(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -991,7 +1011,7 @@ def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var: + def _conv2d_transpose_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -1040,7 +1060,7 @@ def _conv3d_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) return self.block_builder.emit(relax.op.add(conv3d, bias)) - def _conv3d(self, node: fx.node.Node) -> relax.Var: + def _conv3d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -1058,7 +1078,7 @@ def _conv3d(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv3d_functional(self, node: fx.node.Node) -> relax.Var: + def _conv3d_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -1077,7 +1097,7 @@ def _conv3d_functional(self, node: fx.node.Node) -> relax.Var: groups=groups, ) - def _max_pool2d(self, node: fx.node.Node) -> relax.Var: + def _max_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: module = self.named_modules[node.target] @@ -1108,7 +1128,7 @@ def _max_pool2d(self, node: fx.node.Node) -> relax.Var: ) ) - def _avg_pool2d(self, node: fx.node.Node) -> relax.Var: + def _avg_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: module = self.named_modules[node.target] @@ -1154,7 +1174,7 @@ def _avg_pool2d(self, node: fx.node.Node) -> relax.Var: def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: from torch import fx - def _impl(node: fx.node.Node) -> relax.Var: + def _impl(node: fx.Node) -> relax.Var: if is_module: module = self.named_modules[node.target] x = self.env[node.args[0]] @@ -1168,7 +1188,7 @@ def _impl(node: fx.node.Node) -> relax.Var: return _impl - def _softmax(self, node: fx.node.Node) -> relax.Var: + def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: module = self.named_modules[node.target] @@ -1179,29 +1199,7 @@ def _softmax(self, node: fx.node.Node) -> relax.Var: assert dim is not None return self.block_builder.emit(relax.op.nn.softmax(x, dim)) - def _log_softmax(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - dim = module.dim - else: - nargs = len(node.args) - dim = node.args[1] if nargs > 1 else node.kwargs["dim"] - assert dim is not None - return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - - def _leakyrelu(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - alpha = module.negative_slope - else: - nargs = len(node.args) - alpha = node.args[1] if nargs > 1 else node.kwargs["negative_slope"] - assert alpha is not None - return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - - def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: + def _batch_norm_2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -1224,7 +1222,7 @@ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) - def _layer_norm(self, node: fx.node.Node) -> relax.Var: + def _layer_norm(self, node: fx.Node) -> relax.Var: import torch # type: ignore from torch.fx.immutable_collections import immutable_list import numpy as np # type: ignore @@ -1291,7 +1289,7 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var: ) ) - def _group_norm(self, node: fx.node.Node) -> relax.Var: + def _group_norm(self, node: fx.Node) -> relax.Var: import torch # type: ignore x = self.env[node.args[0]] @@ -1317,7 +1315,7 @@ def _group_norm(self, node: fx.node.Node) -> relax.Var: ) ) - def _embedding(self, node: fx.node.Node) -> relax.Var: + def _embedding(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -1333,7 +1331,7 @@ def _embedding(self, node: fx.node.Node) -> relax.Var: embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) - def _interpolate(self, node: fx.node.Node) -> relax.Var: + def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, # recompute_scale_factor=None, antialias=False) @@ -1407,7 +1405,7 @@ def _interpolate(self, node: fx.node.Node) -> relax.Var: ) ) - def _cross_entropy(self, node: fx.node.Node) -> relax.Expr: + def _cross_entropy(self, node: fx.Node) -> relax.Expr: preds = self.env[node.args[0]] targets = self.env[node.args[1]] @@ -1442,7 +1440,7 @@ def _cross_entropy(self, node: fx.node.Node) -> relax.Expr: ) ) - def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var: + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: assert ( len(node.args) <= 4 ), "Dropout is not supported, and is_causal should be called by kwargs." @@ -1464,13 +1462,13 @@ def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var: ########## Others ########## - def _sym_size_int(self, node: fx.node.Node) -> relax.Expr: + def _sym_size_int(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) idx = node.args[1] return self.block_builder.emit(relax.const(shape[idx].value, "int32")) - def _size(self, node: fx.node.Node) -> relax.Expr: + def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) if len(node.args) == 1: @@ -1480,7 +1478,7 @@ def _size(self, node: fx.node.Node) -> relax.Expr: idx = node.args[1] return self.shape_of(x)[idx].value - def _getattr(self, node: fx.node.Node) -> relax.Var: + def _getattr(self, node: fx.Node) -> relax.Var: if isinstance(self.env[node.args[0]], relax.Expr): if node.args[1] == "dtype": return self.env[node.args[0]].struct_info.dtype @@ -1488,7 +1486,7 @@ def _getattr(self, node: fx.node.Node) -> relax.Var: return self.shape_of(self.env[node.args[0]]) return getattr(self.env[node.args[0]], node.args[1]) - def _getitem(self, node: fx.node.Node) -> relax.Var: + def _getitem(self, node: fx.Node) -> relax.Var: import torch x = self.env[node.args[0]] @@ -1510,7 +1508,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var: shape = self.shape_of(x) non_ellipsis_cnt = 0 for index in node.args[1]: - if isinstance(index, (int, slice, torch.fx.node.Node)): + if isinstance(index, (int, slice, torch.fx.Node)): non_ellipsis_cnt += 1 for index in node.args[1]: if isinstance(index, int): @@ -1534,7 +1532,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var: stride.append(1) stride_axes.append(i) i += 1 - elif isinstance(index, torch.fx.node.Node): + elif isinstance(index, torch.fx.Node): node_index = self.env[index] if not isinstance(node_index, relax.Expr): raise ValueError( @@ -1573,142 +1571,154 @@ def create_convert_map(self): from torch import nn from torch import fx - self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = { - # call_module - nn.Linear: self._linear, + self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.Node], relax.Var]] = { + ## call_module + # unary + nn.Dropout: lambda node: self.env[node.args[0]], + nn.GELU: self._gelu, + nn.Hardsigmoid: self._hardsigmoid, + nn.Hardswish: self._hardswish, + nn.Identity: lambda node: self.env[node.args[0]], + nn.LeakyReLU: self._leakyrelu_module, + nn.LogSoftmax: self._log_softmax_module, + nn.ReLU: self._unary_op(relax.op.nn.relu), + nn.ReLU6: lambda node: self.block_builder.emit( + relax.op.clip(self.env[node.args[0]], 0, 6) + ), + nn.Sigmoid: self._unary_op(relax.op.sigmoid), + nn.SiLU: self._unary_op(relax.op.nn.silu), + nn.Softmax: self._softmax_module, + nn.Tanh: self._unary_op(relax.op.tanh), + # neural network + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), + nn.AvgPool2d: self._avg_pool2d, + nn.BatchNorm2d: self._batch_norm_2d, nn.Conv1d: self._conv1d, nn.Conv2d: self._conv2d, nn.Conv3d: self._conv3d, nn.ConvTranspose1d: self._conv1d_transpose, nn.ConvTranspose2d: self._conv2d_transpose, - nn.MaxPool2d: self._max_pool2d, - nn.AvgPool2d: self._avg_pool2d, - nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), - nn.Softmax: self._softmax, - nn.LogSoftmax: self._log_softmax, - nn.ReLU: lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), - nn.LeakyReLU: self._leakyrelu, - nn.ReLU6: lambda node: self.block_builder.emit( - relax.op.clip(self.env[node.args[0]], 0, 6) - ), - nn.GELU: self._gelu, - nn.Sigmoid: self._sigmoid, - nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), - nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), - nn.Hardsigmoid: self._hardsigmoid, - nn.Hardswish: self._hardswish, - nn.Flatten: self._flatten, - nn.BatchNorm2d: self._batch_norm_2d, - nn.LayerNorm: self._layer_norm, + nn.CrossEntropyLoss: self._cross_entropy, nn.GroupNorm: self._group_norm, - nn.Dropout: lambda node: self.env[node.args[0]], - nn.Identity: lambda node: self.env[node.args[0]], + nn.LayerNorm: self._layer_norm, + nn.Linear: self._linear, + nn.MaxPool2d: self._max_pool2d, nn.modules.sparse.Embedding: self._embedding, - nn.CrossEntropyLoss: self._cross_entropy, - # call_function and call_method - "sin": lambda node: self.block_builder.emit(relax.op.sin(self.env[node.args[0]])), - "cos": lambda node: self.block_builder.emit(relax.op.cos(self.env[node.args[0]])), - "tan": lambda node: self.block_builder.emit(relax.op.tan(self.env[node.args[0]])), - "asin": lambda node: self.block_builder.emit(relax.op.asin(self.env[node.args[0]])), - "acos": lambda node: self.block_builder.emit(relax.op.acos(self.env[node.args[0]])), - "atan": lambda node: self.block_builder.emit(relax.op.atan(self.env[node.args[0]])), - "sinh": lambda node: self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])), - "cosh": lambda node: self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])), - "tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), - "asinh": lambda node: self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])), - "acosh": lambda node: self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])), - "atanh": lambda node: self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])), - "exp": self._exp, - "iadd": self._add, + # tensor manipulation + nn.Flatten: self._flatten, + ## call_function and call_method + # unary + "acos": self._unary_op(relax.op.acos), + "acosh": self._unary_op(relax.op.acosh), + "asin": self._unary_op(relax.op.asin), + "asinh": self._unary_op(relax.op.asinh), + "atan": self._unary_op(relax.op.atan), + "atanh": self._unary_op(relax.op.atanh), + "clamp": self._clamp, + "cos": self._unary_op(relax.op.cos), + "cosh": self._unary_op(relax.op.cosh), + "dropout": lambda node: self.env[node.args[0]], + "exp": self._unary_op(relax.op.exp), + "gelu": self._gelu, + "hardsigmoid": self._hardsigmoid, + "hardswish": self._hardswish, + "leaky_relu": self._leakyrelu, + "log_softmax": self._log_softmax, + "neg": self._unary_op(relax.op.negative), + "relu": self._unary_op(relax.op.nn.relu), + "round": self._round, + "rsqrt": self._unary_op(relax.op.rsqrt), + "sigmoid": self._unary_op(relax.op.sigmoid), + "silu": self._unary_op(relax.op.nn.silu), + "sin": self._unary_op(relax.op.sin), + "sinh": self._unary_op(relax.op.sinh), + "softmax": self._softmax, + "sqrt": self._unary_op(relax.op.sqrt), + "tan": self._unary_op(relax.op.tan), + "tanh": self._unary_op(relax.op.tanh), + "tril_": self._inplace_tril_triu(relax.op.tril), + "tril": self._tril_triu(relax.op.tril), + "triu_": self._inplace_tril_triu(relax.op.triu), + "triu": self._tril_triu(relax.op.triu), + # binary "add": self._add, + "eq": self._eq, "floordiv": self._floordiv, + "iadd": self._add, + "lt": self._lt, + "matmul": self._matmul, + "max": self._max, "mul": self._mul, - "sub": self._sub, "pow": self._pow, - "sigmoid": self._sigmoid, - "sqrt": self._sqrt, - "round": self._round, - "lt": self._lt, - "eq": self._eq, + "sub": self._sub, "truediv": self._truediv, - "fill_": self._inplace_fill, - "new_ones": self._new_ones, - "arange": self._arange, - "empty": self._empty, - "tensor": self._tensor, - "tril": self._tril_triu(relax.op.tril), - "triu": self._tril_triu(relax.op.triu), - "tril_": self._inplace_tril_triu(relax.op.tril), - "triu_": self._inplace_tril_triu(relax.op.triu), - "sum": self._sum, - "float": self._float, - "half": self._half, - "type": self._type, - "astype": self._type, - "matmul": self._matmul, - "conv1d": self._conv1d_functional, + # neural network + "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + "addmm": self._addmm, + "avg_pool2d": self._avg_pool2d, + "baddbmm": self._baddbmm, + "bmm": self._matmul, "conv_transpose1d": self._conv1d_transpose_functional, - "conv2d": self._conv2d_functional, "conv_transpose2d": self._conv2d_transpose_functional, + "conv1d": self._conv1d_functional, + "conv2d": self._conv2d_functional, "conv3d": self._conv3d_functional, + "cross_entropy": self._cross_entropy, + "einsum": self._einsum, + "interpolate": self._interpolate, + "layer_norm": self._layer_norm, "linear": self._linear_functional, - "addmm": self._addmm, - "baddbmm": self._baddbmm, - "bmm": self._matmul, + "max_pool2d": self._max_pool2d, + "scaled_dot_product_attention": self._scaled_dot_product_attention, + "stochastic_depth": lambda node: self.env[node.args[0]], + "unbind": self._unbind, + # statistical + "mean": self._mean, + "sum": self._sum, + # search + "argmax": self._argmax_argmin(relax.op.argmax), + "argmin": self._argmax_argmin(relax.op.argmin), + # tensor manipulation "cat": self._cat, "concat": self._cat, + "contiguous": lambda node: self.env[node.args[0]], + "cumsum": self._cumsum, "expand": self._expand, "flatten": self._flatten, "permute": self._permute, "repeat": self._repeat, "reshape": self._reshape, + "size": self._size, "split": self._split, + "squeeze": self._squeeze, "tile": self._tile, - "cumsum": self._cumsum, - "chunk": self._chunk, "transpose": self._transpose, - "squeeze": self._squeeze, "unsqueeze": lambda node: self.block_builder.emit( relax.op.expand_dims(self.env[node.args[0]], node.args[1]) ), "view": self._reshape, - "argmax": self._argmax_argmin(relax.op.argmax), - "argmin": self._argmax_argmin(relax.op.argmin), - "softmax": self._softmax, - "log_softmax": self._log_softmax, - "dropout": lambda node: self.env[node.args[0]], - "stochastic_depth": lambda node: self.env[node.args[0]], - "clamp": self._clamp, - "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), - "leaky_relu": self._leakyrelu, - "gelu": self._gelu, - "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), - "hardsigmoid": self._hardsigmoid, - "hardswish": self._hardswish, - "interpolate": self._interpolate, - "sym_size.int": self._sym_size_int, - "size": self._size, - "getattr": self._getattr, - "getitem": self._getitem, - "contiguous": lambda node: self.env[node.args[0]], - "to": self._to, - "max_pool2d": self._max_pool2d, - "avg_pool2d": self._avg_pool2d, - "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), - "layer_norm": self._layer_norm, + # tensor creation + "arange": self._arange, + "chunk": self._chunk, + "empty": self._empty, + "fill_": self._inplace_fill, + "full": self._full, "index_select": self._index_select, + "masked_fill_": self._inplace_masked_fill, "masked_fill": self._masked_fill, + "new_ones": self._new_ones, "ones": self._ones, - "full": self._full, - "masked_fill_": self._inplace_masked_fill, - "mean": self._mean, - "rsqrt": self._rsqrt, - "neg": self._neg, - "max": self._max, - "cross_entropy": self._cross_entropy, - "scaled_dot_product_attention": self._scaled_dot_product_attention, - "einsum": self._einsum, - "unbind": self._unbind, + "tensor": self._tensor, + "to": self._to, + # datatype + "astype": self._type, + "float": self._float, + "half": self._half, + "type": self._type, + # other + "getattr": self._getattr, + "getitem": self._getitem, + "sym_size.int": self._sym_size_int, } def update_convert_map(self, custom_convert_map: dict): From 5265d215fe26df3172fa0375030802f90289fe53 Mon Sep 17 00:00:00 2001 From: Ivan Sidorenko <98739392+ibsidorenko@users.noreply.github.com> Date: Thu, 12 Sep 2024 01:16:56 +0300 Subject: [PATCH 133/202] [Relax] Add new NN allgather operator (#17359) This commit adds wrapper for Relax NCCL allgather operator. --- python/tvm/relax/frontend/nn/op.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 04c030bea6fa..4664ec549388 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1719,6 +1719,28 @@ def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, name=" return wrap_nested(_op.ccl.allreduce(x._expr, op_type, in_group), name) +def ccl_allgather(x: Tensor, num_workers: int, name="ccl_allgather"): + """CCL Allgather operator + + Parameters + ---------- + x : relax.Expr + The input tensor. + + num_workers : int + Number of workers. + + name : str + Name hint for this operation. + + Returns + ------- + result : Tensor + The result tensor of allgather. + """ + return wrap_nested(_op.ccl.allgather(x._expr, num_workers), name) + + def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"): """Broadcast data from worker-0 to all other workers. From 31da94717377df367803c7c0ce8b3451b927a702 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 21:18:13 +0900 Subject: [PATCH 134/202] [Relax][PyTorch] Cleanup binary op converters (#17366) * introduce `_binary_op()` * cleanup --- .../tvm/relax/frontend/torch/fx_translator.py | 146 ++++++------------ 1 file changed, 49 insertions(+), 97 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 8d66343254c1..7efc2412eaf7 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -19,7 +19,7 @@ # pylint: disable=import-outside-toplevel """PyTorch FX frontend of Relax.""" from typing import Callable, Dict, List, Optional, Tuple, Union -from functools import reduce +from functools import partial, reduce import tvm from tvm import relax @@ -119,23 +119,6 @@ def _retrieve_args(self, node): else: return node - @staticmethod - def _promote_binary_op_args(lhs, rhs): - if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): - return lhs, rhs - elif isinstance(lhs, relax.Expr): - assert isinstance(lhs.struct_info, relax.TensorStructInfo) - return lhs, relax.const(rhs, lhs.struct_info.dtype) - elif isinstance(rhs, relax.Expr): - assert isinstance(rhs.struct_info, relax.TensorStructInfo) - return relax.const(lhs, rhs.struct_info.dtype), rhs - else: - assert False - - def _call_binary_op(self, op, lhs, rhs): - lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs) - return self.block_builder.emit(op(lhs, rhs)) - ########## Unary Ops ########## def _unary_op(self, op: Callable) -> Callable: @@ -240,66 +223,38 @@ def convert(node: fx.Node) -> relax.Var: return convert - ########## Arithmetic ########## + ########## Binary Ops ########## - def _add(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return self._call_binary_op(relax.op.add, lhs, rhs) - elif isinstance(lhs, relax.expr.Constant): - return self._call_binary_op( - relax.op.add, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype) - ) - elif isinstance(rhs, relax.expr.Constant): - return self._call_binary_op( - relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs - ) - return lhs + rhs - - def _max(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return self._call_binary_op(relax.op.maximum, lhs, rhs) - - def _floordiv(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return self._call_binary_op(relax.op.floor_divide, lhs, rhs) - return lhs // rhs - - def _mul(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return self._call_binary_op(relax.op.multiply, lhs, rhs) - return lhs * rhs - - def _pow(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return self._call_binary_op(relax.op.power, lhs, rhs) - return lhs**rhs - - def _sub(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return self._call_binary_op(relax.op.subtract, lhs, rhs) - return lhs - rhs - - def _truediv(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return self._call_binary_op(relax.op.divide, lhs, rhs) - return lhs / rhs - - ########## Compare ########## - - def _lt(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - return self._call_binary_op(relax.op.less, lhs, rhs) - - def _eq(self, node: fx.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - return self._call_binary_op(relax.op.equal, lhs, rhs) + def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + def promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def call_binary_op(op, lhs, rhs): + lhs, rhs = promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return call_binary_op(relax_op, lhs, rhs) + elif isinstance(lhs, relax.expr.Constant): + return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) + elif isinstance(rhs, relax.expr.Constant): + return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) + return intrinsic_op(lhs, rhs) + + return convert ########## Creation ########## @@ -486,14 +441,6 @@ def _to(self, node: fx.Node) -> relax.Var: def _matmul_impl(self, a: relax.Expr, b: relax.Expr): return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _matmul(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - res = self._matmul_impl( - args[0], - args[1], - ) - return res - def _addmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] y = self.env[node.args[1]] @@ -1568,6 +1515,7 @@ def _getitem(self, node: fx.Node) -> relax.Var: assert False def create_convert_map(self): + import operator from torch import nn from torch import fx @@ -1641,23 +1589,27 @@ def create_convert_map(self): "triu_": self._inplace_tril_triu(relax.op.triu), "triu": self._tril_triu(relax.op.triu), # binary - "add": self._add, - "eq": self._eq, - "floordiv": self._floordiv, - "iadd": self._add, - "lt": self._lt, - "matmul": self._matmul, - "max": self._max, - "mul": self._mul, - "pow": self._pow, - "sub": self._sub, - "truediv": self._truediv, + "add": self._binary_op(relax.op.add, operator.add), + "eq": self._binary_op(relax.op.equal, operator.eq), + "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv), + "iadd": self._binary_op(relax.op.add, operator.add), + "lt": self._binary_op(relax.op.less, operator.lt), + "matmul": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), + "max": self._binary_op(relax.op.maximum, max), + "mul": self._binary_op(relax.op.multiply, operator.mul), + "pow": self._binary_op(relax.op.power, operator.pow), + "sub": self._binary_op(relax.op.subtract, operator.sub), + "truediv": self._binary_op(relax.op.divide, operator.truediv), # neural network "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), "addmm": self._addmm, "avg_pool2d": self._avg_pool2d, "baddbmm": self._baddbmm, - "bmm": self._matmul, + "bmm": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), "conv_transpose1d": self._conv1d_transpose_functional, "conv_transpose2d": self._conv2d_transpose_functional, "conv1d": self._conv1d_functional, From 090430a284652057ea0f2c8909d2af0bea0e3454 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 12 Sep 2024 21:21:26 +0800 Subject: [PATCH 135/202] [DLight] Fix Matmul rule for Conv3D (#17363) Currently, the matmul rule for Conv3D is incorrect, due to the incorrect reindexing of the input tensor. This commit fixes the issue by correctly The `index map` of `transform_layout` should be calculated after the `reindex` process --- python/tvm/dlight/gpu/matmul.py | 100 ++++++++++++----------- tests/python/dlight/test_gpu_conv.py | 118 +++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 48 deletions(-) create mode 100644 tests/python/dlight/test_gpu_conv.py diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 5fb8e2469d54..5568083982b9 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -364,13 +364,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if reduction_blocks is None: return None - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - # Step 0. Configs block_size_x: int = 16 block_size_y: int = 16 @@ -382,12 +375,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring vector_size: int = 4 # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) + # Reindex first and than analyze the index map + main_block = reduction_blocks[0] + reindex_a = sch.reindex(main_block, ("read", 0)) + reindex_b = sch.reindex(main_block, ("read", 1)) + reindex_c = sch.reindex(main_block, ("write", 0)) + + index_maps = get_index_map(sch.get(main_block)) + assert index_maps is not None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + sch.transform_layout(reindex_a, ("write", 0), a_index_map) + sch.transform_layout(reindex_b, ("write", 0), b_index_map) + sch.transform_layout(reindex_c, ("read", 0), c_index_map) sch.transform_block_layout(main_block, matmul_index_map) # Step 2. Padding for dynamic shape kernels @@ -508,13 +508,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if reduction_blocks is None: return None - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - # Start Schedule # Step 0. Get schedule config. # NOTE: we can analyze the config by the hardware spec in the future @@ -539,12 +532,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring k_pad_factor = k_factors[1] # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) + # Reindex first and than analyze the index map + main_block = reduction_blocks[0] + reindex_a = sch.reindex(main_block, ("read", 0)) + reindex_b = sch.reindex(main_block, ("read", 1)) + reindex_c = sch.reindex(main_block, ("write", 0)) + + index_maps = get_index_map(sch.get(main_block)) + assert index_maps is not None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + sch.transform_layout(reindex_a, ("write", 0), a_index_map) + sch.transform_layout(reindex_b, ("write", 0), b_index_map) + sch.transform_layout(reindex_c, ("read", 0), c_index_map) sch.transform_block_layout(main_block, matmul_index_map) # Step 2. Padding for dynamic shape kernels @@ -729,13 +729,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if reduction_blocks is None: return None - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - # Start Schedule # Step 0. Get schedule config. # NOTE: we can analyze the config by the hardware spec in the future @@ -760,12 +753,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring k_pad_factor = k_factors[1] # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) + # Reindex first and than analyze the index map + main_block = reduction_blocks[0] + reindex_a = sch.reindex(main_block, ("read", 0)) + reindex_b = sch.reindex(main_block, ("read", 1)) + reindex_c = sch.reindex(main_block, ("write", 0)) + + index_maps = get_index_map(sch.get(main_block)) + assert index_maps is not None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + sch.transform_layout(reindex_a, ("write", 0), a_index_map) + sch.transform_layout(reindex_b, ("write", 0), b_index_map) + sch.transform_layout(reindex_c, ("read", 0), c_index_map) sch.transform_block_layout(main_block, matmul_index_map) # Step 2. Padding for dynamic shape kernels @@ -979,12 +979,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring main_block = reduction_blocks[0] block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None main_block_info = get_block_info(sch, main_block) iter_infos = main_block_info.iters + if not get_index_map(block_stmt): + return None # Checks if it's a inner reduction by getting the last matrix's inner Index def is_inner_reduction(block_stmt, iter_infos): @@ -1000,13 +999,18 @@ def is_inner_reduction(block_stmt, iter_infos): return ret # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + # Reindex first and than analyze the index map + reindex_a = sch.reindex(main_block, ("read", 0)) + reindex_b = sch.reindex(main_block, ("read", 1)) + reindex_c = sch.reindex(main_block, ("write", 0)) + + index_maps = get_index_map(sch.get(main_block)) + assert index_maps is not None matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) + + sch.transform_layout(reindex_a, ("write", 0), a_index_map) + sch.transform_layout(reindex_b, ("write", 0), b_index_map) + sch.transform_layout(reindex_c, ("read", 0), c_index_map) sch.transform_block_layout(main_block, matmul_index_map) # Step 1. Check Tensor Core support diff --git a/tests/python/dlight/test_gpu_conv.py b/tests/python/dlight/test_gpu_conv.py new file mode 100644 index 000000000000..4997975dd311 --- /dev/null +++ b/tests/python/dlight/test_gpu_conv.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import pytest + +import tvm.testing +from tvm import dlight as dl +from tvm.script import tir as T +from tvm.target import Target + + +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + @pytest.fixture + def transform(self): + def transform(mod): + with Target("nvidia/geforce-gtx-1080-ti"): + # Use Matmul rule for Conv for now + return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) + + return transform + + +class TestConv3d(BaseBeforeAfter): + # fmt: off + @T.prim_func + def before( + A: T.Buffer((14308, 3, 2, 14, 14), "float16"), + W: T.Buffer((1280, 3, 2, 14, 14), "float16"), + C: T.Buffer((14308, 1280, 1, 1, 1), "float16"), + ): + pad_A = T.alloc_buffer((14308, 3, 2, 14, 14), "float16") + for i0, i1, i2, i3, i4 in T.grid(14308, 3, 2, 14, 14): + with T.block("pad_A"): + v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + pad_A[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4] + for nn, ff, yy, xx, zz, rc, ry, rx, rz in T.grid(14308, 1280, 1, 1, 1, 3, 2, 14, 14): + with T.block("C"): + v_nn, v_ff, v_yy, v_xx, v_zz, v_rc, v_ry, v_rx, v_rz = T.axis.remap("SSSSSRRRR", [nn, ff, yy, xx, zz, rc, ry, rx, rz]) + with T.init(): + C[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0.0) + C[v_nn, v_ff, v_yy, v_xx, v_zz] += pad_A[v_nn, v_rc, v_yy * 2 + v_ry, v_xx * 14 + v_rx, v_zz * 14 + v_rz]* W[v_ff, v_rc, v_ry, v_rx, v_rz] + + @T.prim_func + def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3, 2, 14, 14), "float16"), C: T.Buffer((14308, 1280, 1, 1, 1), "float16")): + T.func_attr({"tir.is_scheduled": 1}) + # with T.block("root"): + C_reindex_pad_local = T.alloc_buffer((1, 14336, 1280), "float16", scope="local") + pad_A_reindex_pad_shared = T.alloc_buffer((1, 14336, 1184), "float16", scope="shared") + W_reindex_pad_shared = T.alloc_buffer((1, 1280, 1184), "float16", scope="shared") + for ax0_ax2_0_fused in T.thread_binding(20, thread="blockIdx.y"): + for ax1_0 in T.thread_binding(448, thread="blockIdx.x"): + for ax2_1 in T.thread_binding(1, thread="vthread.y"): + for ax1_1 in T.thread_binding(1, thread="vthread.x"): + for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): + for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_3_init, ax2_3_0_init in T.grid(4, 2): + for ax2_3_1_init in T.vectorized(2): + with T.block("C_init"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) + v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init) + C_reindex_pad_local[0, v1, v2] = T.float16(0.0) + for ax3_0 in range(74): + for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in range(2): + for ax0_ax1_ax2_fused_3 in T.vectorized(2): + with T.block("pad_A_reindex_pad_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(14336, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) + v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) + T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + pad_A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 14308 and v2 < 1176, A[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0)) + for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in range(4): + for ax0_ax1_ax2_fused_3 in T.vectorized(2): + with T.block("W_reindex_pad_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) + v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) + T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 1176, W[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0)) + for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2): + for ax2_3_1 in T.vectorized(2): + with T.block("C_update"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) + v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1) + v3 = T.axis.reduce(1184, ax3_0 * 16 + ax3_1) + C_reindex_pad_local[0, v1, v2] = C_reindex_pad_local[0, v1, v2] + pad_A_reindex_pad_shared[0, v1, v3] * W_reindex_pad_shared[0, v2, v3] + for ax0, ax1, ax2_0 in T.grid(1, 4, 2): + for ax2_1_1 in T.vectorized(2): + with T.block("C_reindex_pad_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_2 * 4 + ax1) + v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) + T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < 14308) + C[v1, v2, 0, 0, 0] = C_reindex_pad_local[v0, v1, v2] + # fmt: on + + +if __name__ == "__main__": + tvm.testing.main() From bd11e19490cb5f1a2081ac1787803428545e22a5 Mon Sep 17 00:00:00 2001 From: PatricYan Date: Fri, 13 Sep 2024 00:25:57 +0800 Subject: [PATCH 136/202] Update tvmc_command_line_driver.py, modify the sentence, remove the duplicate "as" (#17358) Update tvmc_command_line_driver.py, modify the sentence, remove the duplicate "as" --- gallery/tutorial/tvmc_command_line_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/tutorial/tvmc_command_line_driver.py b/gallery/tutorial/tvmc_command_line_driver.py index a20dcb9c96a4..58a8dc212d9f 100644 --- a/gallery/tutorial/tvmc_command_line_driver.py +++ b/gallery/tutorial/tvmc_command_line_driver.py @@ -47,7 +47,7 @@ # ---------- # # TVMC is a Python application, part of the TVM Python package. -# When you install TVM using a Python package, you will get TVMC as +# When you install TVM using a Python package, you will get TVMC # as a command line application called ``tvmc``. The location of this command # will vary depending on your platform and installation method. # From b8b5fb6a1c63bdd3409e2e266d2ac386f8fbbb26 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 12 Sep 2024 13:25:23 -0500 Subject: [PATCH 137/202] [IR] Expose ReplaceGlobalVars utility in the Python API (#17361) * [IR] Expose ReplaceGlobalVars utility in the Python API This is a follow-up PR to https://github.com/apache/tvm/pull/17202, which added a general utility to replace `GlobalVar` instances across all TVM IR types. This PR exposes this new utility through the Python API, and explicitly tests its functionality. * Lint fix --- ...ace_global_var.h => replace_global_vars.h} | 10 +- python/tvm/ir/module.py | 28 ++ ...e_global_var.cc => replace_global_vars.cc} | 43 ++- src/relax/transform/attach_global_symbol.cc | 4 +- ...e_global_var.cc => replace_global_vars.cc} | 23 +- ...e_global_var.cc => replace_global_vars.cc} | 20 +- .../ir/test_transform_replace_global_var.py | 306 ++++++++++++++++++ 7 files changed, 418 insertions(+), 16 deletions(-) rename include/tvm/ir/{replace_global_var.h => replace_global_vars.h} (85%) rename src/ir/{replace_global_var.cc => replace_global_vars.cc} (55%) rename src/relax/transform/{replace_global_var.cc => replace_global_vars.cc} (72%) rename src/tir/transforms/{replace_global_var.cc => replace_global_vars.cc} (75%) create mode 100644 tests/python/ir/test_transform_replace_global_var.py diff --git a/include/tvm/ir/replace_global_var.h b/include/tvm/ir/replace_global_vars.h similarity index 85% rename from include/tvm/ir/replace_global_var.h rename to include/tvm/ir/replace_global_vars.h index c15dd5f4e5ad..ea91d46d7c0a 100644 --- a/include/tvm/ir/replace_global_var.h +++ b/include/tvm/ir/replace_global_vars.h @@ -18,13 +18,13 @@ */ /*! - * \file tvm/ir/replace_global_var.h + * \file tvm/ir/replace_global_vars.h * * \brief A utility to replace GlobalVar instances across all TVM IR * types in an IRMdoule. */ -#ifndef TVM_IR_REPLACE_GLOBAL_VAR_H_ -#define TVM_IR_REPLACE_GLOBAL_VAR_H_ +#ifndef TVM_IR_REPLACE_GLOBAL_VARS_H_ +#define TVM_IR_REPLACE_GLOBAL_VARS_H_ #include @@ -41,7 +41,7 @@ namespace transform { * * \return The updated IRModule */ -TVM_DLL IRModule ReplaceGlobalVar(IRModule mod, Map replacements); +TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, Map replacements); struct GlobalVarReplacer { using FType = NodeFunctor)>; @@ -54,4 +54,4 @@ struct GlobalVarReplacer { } // namespace transform } // namespace tvm -#endif // TVM_IR_REPLACE_GLOBAL_VAR_H_ +#endif // TVM_IR_REPLACE_GLOBAL_VARS_H_ diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index ea3ef6d8831b..3c76dbfdd839 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """IRModule that holds the functions and type definitions.""" + from __future__ import annotations from typing import Dict, Union @@ -216,6 +217,33 @@ def get_global_vars(self): """ return _ffi_api.Module_GetGlobalVars(self) + def replace_global_vars( + self, + replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]], + ) -> "IRModule": + """Replace GlobalVar instances within the module + + Replace GlobalVars within the IRModule. Since the IRModule + may contain internal references to a GlobalVar, either in TIR + or in Relax, this method should be used whenever replacing or + renaming a GlobalVar. + + Parameters + ---------- + replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]] + + A dictionary where each key is a GlobalVar to be replaced, + and the corresponding value is the GlobalVar with which to + replace it. + + Returns + ------- + IRModule + The updated module + + """ + return _ffi_api.Module_ReplaceGlobalVars(self, replacements) + def get_global_type_vars(self): """Collect all global type vars defined in this module. diff --git a/src/ir/replace_global_var.cc b/src/ir/replace_global_vars.cc similarity index 55% rename from src/ir/replace_global_var.cc rename to src/ir/replace_global_vars.cc index 08d66d0e7cf2..9607dab11a6a 100644 --- a/src/ir/replace_global_var.cc +++ b/src/ir/replace_global_vars.cc @@ -18,18 +18,22 @@ */ /*! - * \file src/ir/replace_global_var.cc + * \file src/ir/replace_global_vars.cc * \brief IRModule transform to replace GlobalVar instances across any IR type. */ -#include +#include #include namespace tvm { namespace transform { -IRModule ReplaceGlobalVar(IRModule mod, Map replacements) { +IRModule ReplaceGlobalVars(IRModule mod, Map replacements) { + if (replacements.empty()) { + return mod; + } + std::vector to_remove; IRModule updates; @@ -57,7 +61,38 @@ IRModule ReplaceGlobalVar(IRModule mod, Map replacements) return mod; } -TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVar").set_body_typed(ReplaceGlobalVar); +TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars); + +IRModule ModuleReplaceGlobalVars( + IRModule mod, Map, Variant> replacements) { + Map gvar_replacements; + for (const auto& [before, after] : replacements) { + GlobalVar gvar_before; + if (auto gvar = before.as()) { + gvar_before = gvar.value(); + } else if (auto str = before.as()) { + gvar_before = mod->GetGlobalVar(str.value()); + } else { + LOG(FATAL) << "Variant must contain either String or GlobalVar"; + } + + GlobalVar gvar_after; + if (auto gvar = after.as()) { + gvar_after = gvar.value(); + } else if (auto str = after.as()) { + gvar_after = gvar_before; + gvar_after.CopyOnWrite()->name_hint = str.value(); + } else { + LOG(FATAL) << "Variant must contain either String or GlobalVar"; + } + + gvar_replacements.Set(gvar_before, gvar_after); + } + + return ReplaceGlobalVars(mod, gvar_replacements); +} + +TVM_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index a517d5a035e2..6f18339436fb 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -22,7 +22,7 @@ */ #include -#include +#include #include #include #include @@ -72,7 +72,7 @@ Pass AttachGlobalSymbol() { mod.CopyOnWrite()->Update(updates); if (gvar_updates.size()) { - mod = tvm::transform::ReplaceGlobalVar(mod, gvar_updates); + mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates); } } return mod; diff --git a/src/relax/transform/replace_global_var.cc b/src/relax/transform/replace_global_vars.cc similarity index 72% rename from src/relax/transform/replace_global_var.cc rename to src/relax/transform/replace_global_vars.cc index b81b831036ff..ea5d5e18d8ff 100644 --- a/src/relax/transform/replace_global_var.cc +++ b/src/relax/transform/replace_global_vars.cc @@ -19,13 +19,13 @@ /*! * - * \file src/relax/transform/replace_global_var.cc + * \file src/relax/transform/replace_global_vars.cc * * \brief GlobalVar replacement across IR types */ #include -#include +#include #include #include #include @@ -53,7 +53,24 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) .set_dispatch([](const ObjectRef& func, Map replacements) -> BaseFunc { Mutator mutator(replacements); - return Downcast(mutator(Downcast(func))); + auto new_func = Downcast(mutator(Downcast(func))); + + // If the function is externally exposed, and is being replaced + // by a GlobalVar with a new name, then the function's + // kGlobalSymbol must be updated to match. + if (auto opt = new_func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = opt.value(); + for (const auto& [before, after] : replacements) { + if (before->name_hint == name) { + if (after->name_hint != name) { + new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, after->name_hint); + } + break; + } + } + } + + return new_func; }); TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) diff --git a/src/tir/transforms/replace_global_var.cc b/src/tir/transforms/replace_global_vars.cc similarity index 75% rename from src/tir/transforms/replace_global_var.cc rename to src/tir/transforms/replace_global_vars.cc index 8ef8ba9276b0..3e8437063775 100644 --- a/src/tir/transforms/replace_global_var.cc +++ b/src/tir/transforms/replace_global_vars.cc @@ -19,12 +19,12 @@ /*! * - * \file src/tir/transforms/replace_global_var.cc + * \file src/tir/transforms/replace_global_vars.cc * * \brief GlobalVar replacement across IR types */ -#include +#include #include #include @@ -61,6 +61,22 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) if (!new_body.same_as(func->body)) { func.CopyOnWrite()->body = new_body; } + + // If the function is externally exposed, and is being replaced + // by a GlobalVar with a new name, then the function's + // kGlobalSymbol must be updated to match. + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = opt.value(); + for (const auto& [before, after] : replacements) { + if (before->name_hint == name) { + if (after->name_hint != name) { + func = WithAttr(func, tvm::attr::kGlobalSymbol, after->name_hint); + } + break; + } + } + } + return func; }); diff --git a/tests/python/ir/test_transform_replace_global_var.py b/tests/python/ir/test_transform_replace_global_var.py new file mode 100644 index 000000000000..d31993141500 --- /dev/null +++ b/tests/python/ir/test_transform_replace_global_var.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + + +def _get_before_module(): + @I.ir_module + class Module: + @R.function + def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Module.relax_subroutine(A) + C = R.call_tir(Module.tir_main, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Module.tir_main(C, D) + + return D + + @R.function(private=True) + def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Module.tir_subroutine(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + return Module + + +def test_no_op_if_no_replacements(): + """If no replacements are performed, the IRModule is unmodified""" + + before = _get_before_module() + expected = before + + after = before.replace_global_vars({}) + + tvm.ir.assert_structural_equal(expected, after) + assert before.same_as(after) + + +def test_replace_relax_main(): + """An externally-exposed Relax function may be replaced + + In this example, the "relax_main" function is renamed. This + requires changing both the GlobalVar used to refer to the + function, and the "global_symbol" attribute of the + externally-exposed function. + + """ + + before = _get_before_module() + after = before.replace_global_vars({"relax_main": "relax_main_with_new_name"}) + + @I.ir_module + class Expected: + @R.function + def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine(A) + C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main(C, D) + + return D + + @R.function(private=True) + def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +def test_replace_relax_subroutine(): + """An internal Relax function may be replaced + + In this example, the "relax_subroutine" function is renamed. This + requires changing both the GlobalVar used to refer to the + function, and the GlobalVar used to call the subroutine within + "relax_main". The "global_symbol" attribute does not need to be + updated, because internal functions do not have this attribute. + + """ + + before = _get_before_module() + after = before.replace_global_vars({"relax_subroutine": "relax_subroutine_with_new_name"}) + + @I.ir_module + class Expected: + @R.function + def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine_with_new_name(A) + C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main(C, D) + + return D + + @R.function(private=True) + def relax_subroutine_with_new_name( + A: R.Tensor([16], "float32"), + ) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +def test_replace_tir_main(): + """An externally-exposed TIR function may be replaced + + In this example, the "tir_main" function is renamed. This + requires changing both the GlobalVar used to refer to the + function, the "global_symbol" attribute of the externally-exposed + function. In addition, calls to the TIR function should be + updated to use the new GlobalVar. + + """ + + before = _get_before_module() + after = before.replace_global_vars({"tir_main": "tir_main_with_new_name"}) + + @I.ir_module + class Expected: + @R.function + def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine(A) + C = R.call_tir(Expected.tir_main_with_new_name, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main_with_new_name(C, D) + + return D + + @R.function(private=True) + def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +def test_replace_tir_subroutine(): + """An internally-exposed TIR function may be replaced + + In this example, the "tir_subroutine" function is renamed. This + requires changing both the GlobalVar used to refer to the + function, and the GlobalVar used to refer to it. Internal + functions do not have the "global_symbol" attribute, so it does + not need to be updated. + + """ + + before = _get_before_module() + after = before.replace_global_vars({"tir_subroutine": "tir_subroutine_with_new_name"}) + + @I.ir_module + class Expected: + @R.function + def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine(A) + C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main(C, D) + + return D + + @R.function(private=True) + def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine_with_new_name(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +def test_simultaneous_replacements(): + """Multiple replacements may be performed simultaneously""" + + before = _get_before_module() + after = before.replace_global_vars( + { + "relax_main": "relax_main_with_new_name", + "relax_subroutine": "relax_subroutine_with_new_name", + "tir_main": "tir_main_with_new_name", + "tir_subroutine": "tir_subroutine_with_new_name", + } + ) + + @I.ir_module + class Expected: + @R.function + def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine_with_new_name(A) + C = R.call_tir(Expected.tir_main_with_new_name, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main_with_new_name(C, D) + + return D + + @R.function(private=True) + def relax_subroutine_with_new_name( + A: R.Tensor([16], "float32"), + ) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine_with_new_name(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +if __name__ == "__main__": + tvm.testing.main() From 751467e98d0f3acd16d2031e5febef91717b9e98 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 12 Sep 2024 15:32:31 -0700 Subject: [PATCH 138/202] [Relax] Fix BYOC removing existing ext mods (#17353) --- src/relax/transform/run_codegen.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index fe0e73d99e99..af9ed2fffce2 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -79,6 +79,10 @@ class CodeGenRunner : ExprMutator { auto out_mod = builder_->GetContextIRModule(); if (ext_mods.size()) { + if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { + auto old_ext_mods = opt_old_ext_mods.value(); + ext_mods.insert(ext_mods.begin(), old_ext_mods.begin(), old_ext_mods.end()); + } out_mod = WithAttr(out_mod, tvm::attr::kExternalMods, std::move(ext_mods)); } From 37555713a023802ad7926addb37a5a8d43fd991f Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 13 Sep 2024 21:29:41 +0900 Subject: [PATCH 139/202] [Relax][PyTorch] Cleanup Neural Network op converters (#17369) * cleanup `_adaptive_avg_pool2d()` * cleanup `addmm()` * cleanup `_avg_pool2d()` * cleanup `_baddbmm()` * cleanup `_conv1d_transpose()` * cleanup `_conv2d_transpose()` * cleanup `_conv1d()` * cleanup `_conv2d()` * cleanup `_conv3d()` * cleanup `_einsum()` * cleanup `_embedding()` * cleanup `_group_norm()` * cleanup `_layer_norm()` * cleanup `_linear()` * cleanup `_max_pool2d()` * cleanup `_scaled_dot_product_attention()` * cleanup `_unbind()` * remove `_matmul_impl()` since we don't use it anymore --- .../tvm/relax/frontend/torch/fx_translator.py | 1526 ++++++++--------- 1 file changed, 755 insertions(+), 771 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 7efc2412eaf7..1c4796a533a4 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -256,197 +256,30 @@ def call_binary_op(op, lhs, rhs): return convert - ########## Creation ########## - - def _arange(self, node: fx.Node) -> relax.Var: - import torch - - start_end_step = [None, None, None] - if "start" in node.kwargs: - start_end_step[0] = node.kwargs["start"] - if "end" in node.kwargs: - start_end_step[1] = node.kwargs["end"] - if "step" in node.kwargs: - start_end_step[2] = node.kwargs["step"] - - if len(node.args) == 1: - assert start_end_step[1] is None - start_end_step[1] = node.args[0] - elif len(node.args) == 2: - assert start_end_step[0] is None - assert start_end_step[1] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - elif len(node.args) == 3: - assert start_end_step[0] is None - assert start_end_step[1] is None - assert start_end_step[2] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - start_end_step[2] = node.args[2] - - if start_end_step[0] is None: - start_end_step[0] = 0 - if start_end_step[2] is None: - start_end_step[2] = 1 - - if "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - elif any([isinstance(x, float) for x in start_end_step]): - dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype()) - else: - dtype = "int64" - start_end_step = [ - self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step - ] - return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) - - def _empty(self, node: fx.Node) -> relax.Var: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - return self.block_builder.emit(relax.op.zeros(node.args, dtype)) - - def _inplace_fill(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) - filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) - self.env[node.args[0]] = filled - return filled - - def _tensor(self, node: fx.Node) -> relax.Var: - dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None - if isinstance(node.args[0], float): - return relax.const(node.args[0], dtype if dtype is not None else "float32") - elif isinstance(node.args[0], int): - return relax.const(node.args[0], dtype if dtype is not None else "int64") - raise ValueError("torch.tensor with value not a float or int is not accepted") - - def _inplace_tril_triu(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - k = node.args[1] if len(node.args) > 1 else 0 - assert isinstance(k, int) - - mutated = self.block_builder.emit(op(x, k)) - self.env[node.args[0]] = mutated - return mutated - - return convert - - def _new_ones(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - self_var = args[0] - size = args[1:] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, self_var.struct_info.dtype), - self_var.struct_info.dtype, - ) - ) - - def _ones(self, node: fx.Node) -> relax.Var: - import torch + ########## Neural Network ########## - args = self.retrieve_args(node) - size = args[0] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - dtype = ( - TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - if "dtype" in node.kwargs - else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) - ) + def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, dtype), - dtype, - ) + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) - def _full(self, node: fx.Node) -> relax.Var: - import torch + def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - size = args[0] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - dtype = ( - TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - if "dtype" in node.kwargs - else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) - ) - value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size return self.block_builder.emit( - relax.op.full( - size, - value, - dtype, - ) + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) - ########## Statistical ########## - - def _sum(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.sum(args[0], args[1])) - - def _mean(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim)) - - ########## DataType ########## - - def _float(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) - - def _half(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) - - def _type(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - - def _to(self, node: fx.Node) -> relax.Var: - import torch - - x = self.env[node.args[0]] - if len(node.args) == 2: - if isinstance(node.args[1], torch.dtype): - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - elif "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - return x - - ########## Linear Algebra ########## - - def _matmul_impl(self, a: relax.Expr, b: relax.Expr): - return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _addmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] y = self.env[node.args[1]] z = self.env[node.args[2]] - alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 - beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) res = None if alpha != 0: @@ -463,12 +296,50 @@ def _addmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) return res + def _avg_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None or stride == [] else stride + return self.block_builder.emit( + relax.op.nn.avg_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _avg_pool2d(self, node: fx.Node) -> relax.Var: + args, kwargs = node.normalized_arguments(node) + x = self.env[args[0]] + kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] + stride = args[2] if len(args) > 2 else kwargs.get("stride", None) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + + def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + ceil_mode = module.ceil_mode + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + def _baddbmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] a = self.env[node.args[1]] b = self.env[node.args[2]] - alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 - beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) res = None if alpha != 0: @@ -485,229 +356,73 @@ def _baddbmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res - def _einsum(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0])) - return self.block_builder.emit(relax.op.einsum(args[1:], args[0])) - - def _unbind(self, node: fx.Node) -> relax.Var: - if len(node.args) == 2: - assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int" - dim = node.args[1] - elif "dim" in node.kwargs: - dim = node.kwargs["dim"] - else: - dim = 0 - x = self.env[node.args[0]] - selections = self.shape_of(x)[dim].value - n_section = list(range(1, selections + 1)) - ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) - for i in range(selections): - ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) - return self.block_builder.emit(relax.Tuple(ret)) + def _conv1d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) - ########## Manipulation ########## + if bias is None: + return conv1d_transpose - def _cat(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - def _expand(self, node: fx.Node) -> relax.Var: + def _conv1d_transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - broadcast_shape, in_shape = [], self.shape_of(args[0]) - for idx, i in enumerate(args[1:]): - if isinstance(i, int) and i == -1: - broadcast_shape.append(in_shape[idx]) - else: - broadcast_shape.append(i) - return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) - - def _flatten(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - start_dim = module.start_dim - end_dim = module.end_dim - else: - start_dim = node.args[1] if len(node.args) >= 2 else 0 - end_dim = node.args[2] if len(node.args) == 3 else -1 - shape = self.shape_of(x) - start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim - end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim - flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) - new_shape = ( - [shape[i] for i in range(0, start_dim)] - + [flattened] - + [shape[i] for i in range(end_dim + 1, len(shape))] + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, ) - return self.block_builder.emit(relax.op.reshape(x, new_shape)) - - def _permute(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) - - def _reshape(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) - - def _split(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - split_size = node.args[1] - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - else: - dim = 0 - if isinstance(split_size, (list, tuple)): - n_section = [] - for s in split_size[:-1]: - cum_sum = 0 if not n_section else n_section[-1] - n_section.append(s + cum_sum) - else: - n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size - return self.block_builder.emit(relax.op.split(x, n_section, dim)) - - def _chunk(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - chunks = node.args[1] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 2: - dim = node.args[2] - else: - dim = 0 - return self.block_builder.emit(relax.op.split(x, chunks, dim)) - - def _transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - full_idx = list(range(len(self.shape_of(args[0])))) - full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] - return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - - def _squeeze(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 1: - dim = node.args[1] - else: - dim = None - return self.block_builder.emit(relax.op.squeeze(x, dim)) - - def _repeat(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - - def _tile(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - - def _cumsum(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 1: - dim = node.args[1] - else: - dim = None - if "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - else: - dtype = None - if "out" in node.kwargs: - raise ValueError("specifying out for cumsum is not supported yet") - - return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) - - def _index_select(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] - index = self.env[node.args[2]] - return self.block_builder.emit(relax.op.take(x, index, dim)) - - def _masked_fill(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - mask = self.env[node.args[1]] - value = node.args[2] - rx_value = relax.const(value) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) - return self.block_builder.emit(relax.op.where(mask, values, x)) - - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - mask = self.env[node.args[1]] - value = node.args[2] - rx_value = relax.const(value) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) - output = self.block_builder.emit(relax.op.where(mask, values, x)) - self.env[node.args[0]] = output - return output - - ########## Search ########## - - def _argmax_argmin(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node): - x = self.env[node.args[0]] - dim = None - keepdims = False - - if len(node.args) > 1: - dim = node.args[1] - if len(node.args) > 2: - keepdims = node.args[2] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - if "keepdim" in node.kwargs: - keepdims = node.kwargs["keepdim"] - if "keepdims" in node.kwargs: - keepdims = node.kwargs["keepdims"] - - return self.block_builder.emit(op(x, dim, keepdims)) - - return convert - - ########## Neural Network ########## - def _linear(self, node: fx.Node) -> relax.Var: + def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] - bias = None if module.bias is None else self.params[module.bias] - return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + bias = self.params.get(module.bias, None) - def _linear_functional(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) - def _conv1d_impl( + def _conv2d_transpose_impl( self, x: relax.Expr, weight: relax.Expr, @@ -717,45 +432,28 @@ def _conv1d_impl( dilation: Optional[Tuple], groups: Optional[Tuple], ) -> relax.Var: - conv1d = self.block_builder.emit( - relax.op.nn.conv1d( + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( x, weight, strides=strides, padding=padding, dilation=dilation, groups=groups, - data_layout="NCW", - kernel_layout="OIW", + data_layout="NCHW", + kernel_layout="OIHW", out_dtype="float32", ) ) if bias is None: - return conv1d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d, bias)) - - def _conv1d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] + return conv2d_transpose - return self._conv1d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - def _conv1d_functional(self, node: fx.Node) -> relax.Var: + def _conv2d_transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -764,7 +462,7 @@ def _conv1d_functional(self, node: fx.Node) -> relax.Var: padding = args[4] if len(args) > 4 else 0 dilation = args[5] if len(args) > 5 else 1 groups = args[6] if len(args) > 6 else 1 - return self._conv1d_impl( + return self._conv2d_transpose_impl( x, weight, bias=bias, @@ -774,7 +472,23 @@ def _conv1d_functional(self, node: fx.Node) -> relax.Var: groups=groups, ) - def _conv1d_transpose_impl( + def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv2d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv1d_impl( self, x: relax.Expr, weight: relax.Expr, @@ -784,8 +498,8 @@ def _conv1d_transpose_impl( dilation: Optional[Tuple], groups: Optional[Tuple], ) -> relax.Var: - conv1d_transpose = self.block_builder.emit( - relax.op.nn.conv1d_transpose( + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( x, weight, strides=strides, @@ -799,31 +513,12 @@ def _conv1d_transpose_impl( ) if bias is None: - return conv1d_transpose - + return conv1d assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - - def _conv1d_transpose(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv1d_transpose_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) + return self.block_builder.emit(relax.op.add(conv1d, bias)) - def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: + def _conv1d(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -832,7 +527,7 @@ def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: padding = args[4] if len(args) > 4 else 0 dilation = args[5] if len(args) > 5 else 1 groups = args[6] if len(args) > 6 else 1 - return self._conv1d_transpose_impl( + return self._conv1d_impl( x, weight, bias=bias, @@ -842,6 +537,22 @@ def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _conv1d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + def _conv2d_impl( self, x: relax.Expr, @@ -873,24 +584,6 @@ def _conv2d_impl( return self.block_builder.emit(relax.op.add(conv2d, bias)) def _conv2d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv2d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - def _conv2d_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -909,7 +602,23 @@ def _conv2d_functional(self, node: fx.Node) -> relax.Var: groups=groups, ) - def _conv2d_transpose_impl( + def _conv2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv3d_impl( self, x: relax.Expr, weight: relax.Expr, @@ -918,37 +627,53 @@ def _conv2d_transpose_impl( padding: Optional[Tuple], dilation: Optional[Tuple], groups: Optional[Tuple], - ) -> relax.Var: - conv2d_transpose = self.block_builder.emit( - relax.op.nn.conv2d_transpose( + ): + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( x, weight, strides=strides, padding=padding, dilation=dilation, groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", + data_layout="NCDHW", + kernel_layout="OIDHW", out_dtype="float32", ) ) if bias is None: - return conv2d_transpose - + return conv3d assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv3d, bias)) - def _conv2d_transpose(self, node: fx.Node) -> relax.Var: + def _conv3d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv3d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] + bias = self.params.get(module.bias, None) - return self._conv2d_transpose_impl( + return self._conv3d_impl( x, weight, bias=bias, @@ -958,182 +683,570 @@ def _conv2d_transpose(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_transpose_functional(self, node: fx.Node) -> relax.Var: + def _einsum(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.einsum(operands, args[0])) + + def _embedding_impl( + self, + x, + weight, + ) -> relax.Var: + x = self.block_builder.emit(relax.op.astype(x, "int32")) + + ndim = x.struct_info.ndim + if ndim == 1: + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = weight.struct_info.shape.values[-1] + x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) + embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) + return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + + def _embedding_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + return self._embedding_impl(x, weight) + + def _group_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + num_groups = module.num_groups + if module.affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) + beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) + eps = module.eps + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + + def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: + from torch.fx.immutable_collections import immutable_list + import numpy as np # type: ignore + + if isinstance(normalized_shape, (immutable_list, tuple)): + normalized_shape = tuple(normalized_shape) + else: + try: + normalized_shape = self.env[normalized_shape] + except TypeError: + normalized_shape = tuple(normalized_shape) + + dim_num = len(normalized_shape) + axes = list(range(-dim_num, 0)) + + if gamma is None: + shape_tuple = [int(s) for s in normalized_shape] + gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) + if beta is None: + shape_tuple = [int(s) for s in normalized_shape] + beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=eps, + ) + ) + + def _layer_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + normalized_shape = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + + def _layer_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + normalized_shape = module.normalized_shape + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) + eps = module.eps + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + + def _linear(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _conv3d_impl( + def _linear_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _max_pool2d_impl( self, x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv3d = self.block_builder.emit( - relax.op.nn.conv3d( + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None else stride + return self.block_builder.emit( + relax.op.nn.max_pool2d( x, - weight, - strides=strides, + pool_size=kernel_size, + strides=stride, padding=padding, dilation=dilation, - groups=groups, - data_layout="NCDHW", - kernel_layout="OIDHW", - out_dtype="float32", + ceil_mode=ceil_mode, + layout="NCHW", ) ) - if bias is None: - return conv3d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) + def _max_pool2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + + def _max_pool2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + + return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = transpose_S_H(self.env[node.args[0]]) + key = transpose_S_H(self.env[node.args[1]]) + value = transpose_S_H(self.env[node.args[2]]) + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) + dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) + assert dropout_p == 0.0, "Dropout is not supported" + is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) + causal_mask = "TopLeft" if is_causal else None + + if attn_mask is not None: + attn_mask = self.env[attn_mask] + msg = "Only a float mask is supported for the attn_mask input." + assert "float" in attn_mask.struct_info.dtype, msg + + return self.block_builder.emit( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) + ) + + def _unbind(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + assert isinstance(dim, int), "Expected 2nd argument of unbind as int" + selections = self.shape_of(x)[dim].value + n_section = list(range(1, selections + 1)) + ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + return self.block_builder.emit(relax.Tuple(ret)) + + ########## Creation ########## + + def _arange(self, node: fx.Node) -> relax.Var: + import torch + + start_end_step = [None, None, None] + if "start" in node.kwargs: + start_end_step[0] = node.kwargs["start"] + if "end" in node.kwargs: + start_end_step[1] = node.kwargs["end"] + if "step" in node.kwargs: + start_end_step[2] = node.kwargs["step"] + + if len(node.args) == 1: + assert start_end_step[1] is None + start_end_step[1] = node.args[0] + elif len(node.args) == 2: + assert start_end_step[0] is None + assert start_end_step[1] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + elif len(node.args) == 3: + assert start_end_step[0] is None + assert start_end_step[1] is None + assert start_end_step[2] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + start_end_step[2] = node.args[2] + + if start_end_step[0] is None: + start_end_step[0] = 0 + if start_end_step[2] is None: + start_end_step[2] = 1 + + if "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + elif any([isinstance(x, float) for x in start_end_step]): + dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype()) + else: + dtype = "int64" + start_end_step = [ + self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step + ] + return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) + + def _empty(self, node: fx.Node) -> relax.Var: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + return self.block_builder.emit(relax.op.zeros(node.args, dtype)) + + def _inplace_fill(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) + filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + self.env[node.args[0]] = filled + return filled + + def _tensor(self, node: fx.Node) -> relax.Var: + dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None + if isinstance(node.args[0], float): + return relax.const(node.args[0], dtype if dtype is not None else "float32") + elif isinstance(node.args[0], int): + return relax.const(node.args[0], dtype if dtype is not None else "int64") + raise ValueError("torch.tensor with value not a float or int is not accepted") + + def _inplace_tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + + mutated = self.block_builder.emit(op(x, k)) + self.env[node.args[0]] = mutated + return mutated + + return convert + + def _new_ones(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) + + def _ones(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = args[0] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + dtype = ( + TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + if "dtype" in node.kwargs + else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) + ) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, dtype), + dtype, + ) + ) + + def _full(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = args[0] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + dtype = ( + TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + if "dtype" in node.kwargs + else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) + ) + value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + return self.block_builder.emit( + relax.op.full( + size, + value, + dtype, + ) + ) + + ########## Statistical ########## + + def _sum(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + def _mean(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim)) + + ########## DataType ########## + + def _float(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _type(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + + def _to(self, node: fx.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x + + ########## Manipulation ########## + + def _cat(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + + def _expand(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + broadcast_shape, in_shape = [], self.shape_of(args[0]) + for idx, i in enumerate(args[1:]): + if isinstance(i, int) and i == -1: + broadcast_shape.append(in_shape[idx]) + else: + broadcast_shape.append(i) + return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) + + def _flatten(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + start_dim = module.start_dim + end_dim = module.end_dim + else: + start_dim = node.args[1] if len(node.args) >= 2 else 0 + end_dim = node.args[2] if len(node.args) == 3 else -1 + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) + + def _permute(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) + + def _reshape(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) - def _conv3d(self, node: fx.Node) -> relax.Var: + def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] + split_size = node.args[1] + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + else: + dim = 0 + if isinstance(split_size, (list, tuple)): + n_section = [] + for s in split_size[:-1]: + cum_sum = 0 if not n_section else n_section[-1] + n_section.append(s + cum_sum) + else: + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + return self.block_builder.emit(relax.op.split(x, n_section, dim)) - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) + def _chunk(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + chunks = node.args[1] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 2: + dim = node.args[2] + else: + dim = 0 + return self.block_builder.emit(relax.op.split(x, chunks, dim)) - def _conv3d_functional(self, node: fx.Node) -> relax.Var: + def _transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) + full_idx = list(range(len(self.shape_of(args[0])))) + full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] + return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - def _max_pool2d(self, node: fx.Node) -> relax.Var: + def _squeeze(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - kernel = module.kernel_size - stride = module.stride - padding = module.padding - dilation = module.dilation - ceil_mode = module.ceil_mode + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1: + dim = node.args[1] else: - nargs = len(node.args) - kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] - stride = node.args[2] if nargs > 2 else node.kwargs["stride"] - padding = node.args[3] if nargs > 3 else node.kwargs["padding"] - dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"] - ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"] + dim = None + return self.block_builder.emit(relax.op.squeeze(x, dim)) - stride = kernel if stride is None else stride + def _repeat(self, node: fx.Node) -> relax.Var: + import torch # type: ignore - return self.block_builder.emit( - relax.op.nn.max_pool2d( - x, - pool_size=kernel, - strides=stride, - padding=padding, - dilation=dilation, - layout="NCHW", - ceil_mode=ceil_mode, - ) - ) + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - def _avg_pool2d(self, node: fx.Node) -> relax.Var: + def _tile(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.tile(args[0], args[1:])) + + def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - kernel = module.kernel_size - stride = module.stride - padding = module.padding - ceil_mode = module.ceil_mode + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1: + dim = node.args[1] else: - nargs = len(node.args) - kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] - if nargs > 2: - stride = node.args[2] - elif "stride" in node.kwargs.keys(): - stride = node.kwargs["stride"] - else: - stride = None - if nargs > 3: - padding = node.args[3] - elif "padding" in node.kwargs.keys(): - padding = node.kwargs["padding"] - else: - padding = 0 - if nargs > 4: - ceil_mode = node.args[4] - elif "ceil_mode" in node.kwargs.keys(): - ceil_mode = node.kwargs["ceil_mode"] - else: - ceil_mode = False + dim = None + if "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + else: + dtype = None + if "out" in node.kwargs: + raise ValueError("specifying out for cumsum is not supported yet") - stride = kernel if stride is None else stride + return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) - return self.block_builder.emit( - relax.op.nn.avg_pool2d( - x, - pool_size=kernel, - strides=stride, - padding=padding, - layout="NCHW", - ceil_mode=ceil_mode, - ) - ) + def _index_select(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = self.env[node.args[2]] + return self.block_builder.emit(relax.op.take(x, index, dim)) - def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: + def _masked_fill(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + return self.block_builder.emit(relax.op.where(mask, values, x)) + + def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + output = self.block_builder.emit(relax.op.where(mask, values, x)) + self.env[node.args[0]] = output + return output + + ########## Search ########## + + def _argmax_argmin(self, op: Callable) -> Callable: from torch import fx - def _impl(node: fx.Node) -> relax.Var: - if is_module: - module = self.named_modules[node.target] - x = self.env[node.args[0]] - output_size = module.output_size - else: - x = self.env[node.args[0]] - output_size = node.args[1] - return self.block_builder.emit( - relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") - ) + def convert(node: fx.Node): + x = self.env[node.args[0]] + dim = None + keepdims = False + + if len(node.args) > 1: + dim = node.args[1] + if len(node.args) > 2: + keepdims = node.args[2] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + if "keepdim" in node.kwargs: + keepdims = node.kwargs["keepdim"] + if "keepdims" in node.kwargs: + keepdims = node.kwargs["keepdims"] - return _impl + return self.block_builder.emit(op(x, dim, keepdims)) + + return convert + + ########## Neural Network ########## def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -1169,115 +1282,6 @@ def _batch_norm_2d(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) - def _layer_norm(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - from torch.fx.immutable_collections import immutable_list - import numpy as np # type: ignore - - x = self.env[node.args[0]] - - # functional.layer_norm - if node.target not in self.named_modules: - # static or symbolic - arg = node.args[1] - if isinstance(arg, (immutable_list, tuple)): - value = tuple(arg) - else: - try: - value = self.env[arg] - except TypeError: - value = tuple(arg) - normalized_shape = value - dim_num = len(normalized_shape) - axes = list(range(-dim_num, 0)) - - gamma = node.kwargs["weight"] - if gamma is None: - shape_tuple = [int(s) for s in normalized_shape] - gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) - else: - gamma = self.env[gamma] - beta = node.kwargs["bias"] - if beta is None: - shape_tuple = [int(s) for s in normalized_shape] - beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) - else: - beta = self.env[beta] - eps = node.kwargs["eps"] - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=eps, - ) - ) - - module = self.named_modules[node.target] - - if module.elementwise_affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) - beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) - dim_num = len(module.normalized_shape) - axes = list(range(-dim_num, 0)) - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=module.eps, - ) - ) - - def _group_norm(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - x = self.env[node.args[0]] - module = self.named_modules[node.target] - - if module.affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) - beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) - - dim = len(self.shape_of(x)) - return self.block_builder.emit( - relax.op.nn.group_norm( - x, - gamma, - beta, - num_groups=module.num_groups, - channel_axis=1, - axes=list(range(2, dim)), - epsilon=module.eps, - ) - ) - - def _embedding(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - x = self.block_builder.emit(relax.op.astype(x, "int32")) - - ndim = x.struct_info.ndim - if ndim == 1: - return self.block_builder.emit(relax.op.take(weight, x, axis=0)) - else: - x_shape = x.struct_info.shape.values - emb_size = weight.struct_info.shape.values[-1] - x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) - embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) - return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) - def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, @@ -1387,26 +1391,6 @@ def _cross_entropy(self, node: fx.Node) -> relax.Expr: ) ) - def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: - assert ( - len(node.args) <= 4 - ), "Dropout is not supported, and is_causal should be called by kwargs." - transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(self.env[node.args[0]]) - key = transpose_S_H(self.env[node.args[1]]) - value = transpose_S_H(self.env[node.args[2]]) - causal_mask = "TopLeft" if node.kwargs.get("is_causal", False) else None - - if len(node.args) == 4: - mask = self.env[node.args[3]] - msg = "Only a float mask is supported for the attn_mask input." - assert "float" in mask.struct_info.dtype, msg - attn = relax.op.nn.attention(query, key, value, bias=mask, causal_mask=causal_mask) - else: - attn = relax.op.nn.attention(query, key, value, causal_mask=causal_mask) - - return self.block_builder.emit(attn) - ########## Others ########## def _sym_size_int(self, node: fx.Node) -> relax.Expr: @@ -1538,20 +1522,20 @@ def create_convert_map(self): nn.Softmax: self._softmax_module, nn.Tanh: self._unary_op(relax.op.tanh), # neural network - nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), - nn.AvgPool2d: self._avg_pool2d, + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, + nn.AvgPool2d: self._avg_pool2d_module, nn.BatchNorm2d: self._batch_norm_2d, - nn.Conv1d: self._conv1d, - nn.Conv2d: self._conv2d, - nn.Conv3d: self._conv3d, - nn.ConvTranspose1d: self._conv1d_transpose, - nn.ConvTranspose2d: self._conv2d_transpose, + nn.Conv1d: self._conv1d_module, + nn.Conv2d: self._conv2d_module, + nn.Conv3d: self._conv3d_module, + nn.ConvTranspose1d: self._conv1d_transpose_module, + nn.ConvTranspose2d: self._conv2d_transpose_module, nn.CrossEntropyLoss: self._cross_entropy, - nn.GroupNorm: self._group_norm, - nn.LayerNorm: self._layer_norm, - nn.Linear: self._linear, - nn.MaxPool2d: self._max_pool2d, - nn.modules.sparse.Embedding: self._embedding, + nn.GroupNorm: self._group_norm_module, + nn.LayerNorm: self._layer_norm_module, + nn.Linear: self._linear_module, + nn.MaxPool2d: self._max_pool2d_module, + nn.modules.sparse.Embedding: self._embedding_module, # tensor manipulation nn.Flatten: self._flatten, ## call_function and call_method @@ -1603,23 +1587,23 @@ def create_convert_map(self): "sub": self._binary_op(relax.op.subtract, operator.sub), "truediv": self._binary_op(relax.op.divide, operator.truediv), # neural network - "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + "adaptive_avg_pool2d": self._adaptive_avg_pool2d, "addmm": self._addmm, "avg_pool2d": self._avg_pool2d, "baddbmm": self._baddbmm, "bmm": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), - "conv_transpose1d": self._conv1d_transpose_functional, - "conv_transpose2d": self._conv2d_transpose_functional, - "conv1d": self._conv1d_functional, - "conv2d": self._conv2d_functional, - "conv3d": self._conv3d_functional, + "conv_transpose1d": self._conv1d_transpose, + "conv_transpose2d": self._conv2d_transpose, + "conv1d": self._conv1d, + "conv2d": self._conv2d, + "conv3d": self._conv3d, "cross_entropy": self._cross_entropy, "einsum": self._einsum, "interpolate": self._interpolate, "layer_norm": self._layer_norm, - "linear": self._linear_functional, + "linear": self._linear, "max_pool2d": self._max_pool2d, "scaled_dot_product_attention": self._scaled_dot_product_attention, "stochastic_depth": lambda node: self.env[node.args[0]], From eb011c75642c90c30c8ca139922fdde82034ee88 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 13 Sep 2024 08:17:28 -0500 Subject: [PATCH 140/202] [Bugfix][Relax] Preserve existing DataflowBlock in ConvertToDataflow (#17148) The `relax.transform.ConvertToDataflow` identifies portions of a Relax function that satisfy the requirements of a `relax::DataflowBlock`, and converts those portions to a new `DataflowBlock`, provided they are at least some minimum number of operations. Prior to this commit, if a function contained a region that would be converted to a `DataflowBlock`, but also contains existing `DataflowBlock`s that were smaller than the size required for creating a `DataflowBlock`, those existing blocks would be erroneously converted to non-dataflow. This commit updates the `ConvertToDataflow` pass to preserve all existing `DataflowBlock` present in the input. --- src/relax/transform/convert_dataflow.cc | 117 ++++++++++-------- .../relax/test_transform_convert_dataflow.py | 106 ++++++++++++++++ 2 files changed, 173 insertions(+), 50 deletions(-) diff --git a/src/relax/transform/convert_dataflow.cc b/src/relax/transform/convert_dataflow.cc index b927307c2e0e..528a466a9bb3 100644 --- a/src/relax/transform/convert_dataflow.cc +++ b/src/relax/transform/convert_dataflow.cc @@ -28,6 +28,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -39,10 +41,59 @@ class DataflowBlockExtractor : public ExprMutator { Array new_blocks; Expr new_body = VisitExpr(seq->body); bool changed = !new_body.same_as(seq->body); - bool dataflow_streak = false; - Array dataflow_bindings; + + // Accumulated bindings that are not going to be added to a + // DataflowBlock, either because they would be illegal within a + // DataflowBlock, or because there were insufficient bindings to + // make a dataflowblock. Because these bindings occur prior to + // `dataflow_bindings`, this array may only be accumulated into + // when `dataflow_bindings` is empty. Array non_dataflow_bindings; + // Current bindings that may legally be added to a DataflowBlock. + Array dataflow_bindings; + + // If present, a DataflowBlock whose bindings are currently in + // `dataflow_bindings`. Used to propagate DataflowBlock to the + // output, even if it doesn't meet the minimum size. + Optional input_dataflow_block; + + // Handle any bindings currently in `dataflow_bindings`. These + // are either pushed to their own block, or to the end of + // `non_dataflow_bindings`, depending on whether the bindings meet + // the minimum size requirement. + auto push_dataflow_bindings = [&]() { + if (dataflow_bindings.empty()) { + // No Dataflow bindings, so no action required. + return; + } + if (dataflow_bindings.size() < min_size_ && !input_dataflow_block) { + // The df block is below the minimum length, and no input + // DataflowBlock needs to be preserved. Combine the blocks + // and reset the dataflow collection. + + non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(), + dataflow_bindings.end()); + + } else { + // A new DataflowBlock can be generated, with bindings that + // occur after the non-dataflow bindings. + new_blocks.push_back(BindingBlock(non_dataflow_bindings)); + new_blocks.push_back(DataflowBlock(dataflow_bindings)); + non_dataflow_bindings = {}; + + // Making a dataflow block doesn't imply that the function was + // changed. A change requires that this either be a new + // dataflow block, or have additional dataflow bindings in the + // current block. + changed = changed || !input_dataflow_block.defined() || + input_dataflow_block.value()->bindings.size() != dataflow_bindings.size(); + } + + dataflow_bindings = {}; + input_dataflow_block = NullOpt; + }; + for (auto block : seq->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); changed = changed || !new_block.same_as(block); @@ -50,74 +101,40 @@ class DataflowBlockExtractor : public ExprMutator { // For an existing dataflow block, we add to the current streak // or start a new streak in case there will be more dataflow operations // coming up - if (new_block.as()) { - if (!dataflow_streak) { - dataflow_streak = true; - } + if (auto dataflow_block = new_block.as()) { dataflow_bindings.insert(dataflow_bindings.end(), new_block->bindings.begin(), new_block->bindings.end()); + input_dataflow_block = dataflow_block; continue; } // for a binding block, attempt to extract dataflow blocks inside auto binding_block = Downcast(new_block); - for (size_t i = 0; i < binding_block->bindings.size(); i++) { - auto binding = binding_block->bindings[i]; + for (const auto& binding : binding_block->bindings) { Expr value = GetBoundValue(binding); // dataflow values: not an if node and not an impure call bool is_dataflow = (!value.as()) && (!(value.as() && IsImpureCall(Downcast(value)))); - if (!dataflow_streak) { - // we can start a dataflow streak - if (is_dataflow) { - dataflow_streak = true; - dataflow_bindings = {binding}; - } else { - non_dataflow_bindings.push_back(binding); - } + if (is_dataflow) { + // extend the streak + dataflow_bindings.push_back(binding); } else { - if (is_dataflow) { - // extend the streak - dataflow_bindings.push_back(binding); - } else { - // this is the end of the streak - dataflow_streak = false; - - // if the df block is below the minimum length, combine the blocks - // and reset the dataflow collection - if (dataflow_bindings.size() < min_size_) { - non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(), - dataflow_bindings.end()); - dataflow_bindings = {}; - } else { - // otherwise insert both collections - changed = true; - new_blocks.push_back(BindingBlock(non_dataflow_bindings)); - new_blocks.push_back(DataflowBlock(dataflow_bindings)); - non_dataflow_bindings = {}; - dataflow_bindings = {}; - } - non_dataflow_bindings.push_back(binding); - } + // End the streak, if one currently exists. + push_dataflow_bindings(); + non_dataflow_bindings.push_back(binding); } } } // handle any remaining bindings - if (dataflow_bindings.size() < min_size_) { - non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(), - dataflow_bindings.end()); - new_blocks.push_back(BindingBlock(non_dataflow_bindings)); - } else { - changed = true; - new_blocks.push_back(BindingBlock(non_dataflow_bindings)); - new_blocks.push_back(DataflowBlock(dataflow_bindings)); - } + push_dataflow_bindings(); + new_blocks.push_back(BindingBlock(non_dataflow_bindings)); - if (!changed) { + if (changed) { + return SeqExpr(new_blocks, new_body); + } else { return GetRef(seq); } - return SeqExpr(new_blocks, new_body); } private: diff --git a/tests/python/relax/test_transform_convert_dataflow.py b/tests/python/relax/test_transform_convert_dataflow.py index 8a926cd4aedc..ab78ec0b3bc7 100644 --- a/tests/python/relax/test_transform_convert_dataflow.py +++ b/tests/python/relax/test_transform_convert_dataflow.py @@ -489,5 +489,111 @@ def main(x: R.Tensor, y: R.Tensor) -> R.Tensor: return v +class TestPreserveExistingDataflowBlocksAtBeginning(ExtractCompare): + """Preserve existing DataflowBlocks + + This is a regression test. In previous implementations, a + DataflowBlock in the input, without enough bindings to become a + new dataflow block, could be accidentally ommitted. + + This test is identical to + `TestPreserveExistingDataflowBlocksAtEnd`, except that the + existing dataflow block is at the beginning of the function. + + """ + + @I.ir_module + class Before: + @R.function(pure=False) + def main(A0: R.Tensor, B0: R.Tensor): + # This DataflowBlock is below the minimum size for a new + # block, but already exists in the input IRModule. + with R.dataflow(): + A1 = R.add(A0, A0) + R.output(A1) + + R.print(format="impure_function") + + # This sequence is large enough that it may be converted + # to a DataflowBlock. + B1 = R.add(B0, B0) + B2 = R.add(B1, B1) + B3 = R.add(B2, B2) + + return (A1, B3) + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(A0: R.Tensor, B0: R.Tensor): + # This dataflow block should be preserved in the output. + with R.dataflow(): + A1 = R.add(A0, A0) + R.output(A1) + + R.print(format="impure_function") + + with R.dataflow(): + B1 = R.add(B0, B0) + B2 = R.add(B1, B1) + B3 = R.add(B2, B2) + R.output(B3) + + return (A1, B3) + + +class TestPreserveExistingDataflowBlocksAtEnd(ExtractCompare): + """Preserve existing DataflowBlocks + + This is a regression test. In previous implementations, a + DataflowBlock in the input, without enough bindings to become a + new dataflow block, could be accidentally ommitted. + + This test is identical to + `TestPreserveExistingDataflowBlocksAtBeginning`, except that the + existing dataflow block is at the end of the function. + + """ + + @I.ir_module + class Before: + @R.function(pure=False) + def main(A0: R.Tensor, B0: R.Tensor): + # This sequence is large enough that it may be converted + # to a DataflowBlock. + B1 = R.add(B0, B0) + B2 = R.add(B1, B1) + B3 = R.add(B2, B2) + + R.print(format="impure_function") + + # This DataflowBlock is below the minimum size for a new + # block, but already exists in the input IRModule. + with R.dataflow(): + A1 = R.add(A0, A0) + R.output(A1) + + return (A1, B3) + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(A0: R.Tensor, B0: R.Tensor): + with R.dataflow(): + B1 = R.add(B0, B0) + B2 = R.add(B1, B1) + B3 = R.add(B2, B2) + R.output(B3) + + R.print(format="impure_function") + + # This dataflow block should be preserved in the output. + with R.dataflow(): + A1 = R.add(A0, A0) + R.output(A1) + + return (A1, B3) + + if __name__ == "__main__": tvm.testing.main() From cea4c850221cbbb757f753408274bdcfbd9bc648 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 14 Sep 2024 07:03:28 -0400 Subject: [PATCH 141/202] [WEBGPU] Update runtime to remove deprecated API (#17371) This PR updates webgpu runtime code to remove deprecated API. unblocks the CI. --- web/src/webgpu.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index 284d6d3887d9..d3d431cf1f70 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -116,7 +116,7 @@ export async function detectGPUDevice(): Promise Date: Sat, 14 Sep 2024 21:16:07 +0800 Subject: [PATCH 142/202] [FIX] fix bug when normalize iter with different lower bounds (#17360) If an iter has been normalized with a lower bound, and then try to normalize with a new lower bound, the iter_min need to be updated only when the new lower bound is smaller than the original one. Co-authored-by: liujiaqiang --- src/arith/iter_affine_map.cc | 2 +- .../arith/test_arith_iter_affine_map.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 77b20fcdf203..d24c278f1048 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -696,7 +696,7 @@ class IterMapRewriter : public ExprMutator { // the delta of iter_min when it is updated when the lower bound predicate is present PrimExpr iter_min_delta = make_const(iter_min.dtype(), 0); if (predicate_induced_min.defined()) { - iter_min_delta = predicate_induced_min.value() - iter_min; + iter_min_delta = max(predicate_induced_min.value(), iter_min) - iter_min; iter_min = max(predicate_induced_min.value(), iter_min); } if (predicate_induced_max.defined()) { diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py index f0e6f05adfad..f34dce5c86fd 100644 --- a/tests/python/arith/test_arith_iter_affine_map.py +++ b/tests/python/arith/test_arith_iter_affine_map.py @@ -346,6 +346,27 @@ def test_predicate(): predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), ) + # constraint with differnent lower bound + assert_iter_sum_pattern( + { + (i * 16 + j) // 23 * 8 + + (i * 16 + j) % 23 + - 15: ( + 64, + 0, + 1, + (i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + tvm.tir.IntImm("int32", -15)), + ) + }, + var_dom([(i, 12), (j, 16)]), + predicate=tvm.tir.And( + tvm.tir.And( + i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i * 16 + j) % 23) + ), + tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23), + ), + ) + # constraint on many disjoint fused iters, case 1 # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1) From 4bc61a14452cdae09231f1085d40a4b04fbe1f75 Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Sat, 14 Sep 2024 23:07:06 -0400 Subject: [PATCH 143/202] [Relax][Transform] Add SelectNode handling in SymbolicMatcher (#17368) This PR added support for handling SelectNode in the SymbolicMatcher class by modifying the VisitExpr_ function to match the true_value and false_value expressions between the current SelectNode and the other expression. If the other expression is not a SelectNode, the matching condition is updated to ensure the current SelectNode expression is equivalent to the other expression. --- src/relax/transform/fuse_tir.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 612e1459c826..fe247645dc24 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -139,6 +139,16 @@ class SymbolicMatcher : ExprFunctor(); + if (rhs) { + VisitExpr(op->true_value, rhs->true_value); + VisitExpr(op->false_value, rhs->false_value); + } else { + must_prove_ = must_prove_ && (GetRef(op) == other); + } + } + arith::Analyzer* analyzer_; Map* var_remap_; PrimExpr must_prove_ = Bool(true); From 48d661c0ee277a6594a845423a384b5e1a743350 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 15 Sep 2024 22:07:58 +0900 Subject: [PATCH 144/202] [Relax][PyTorch] Cleanup Statistical, Search and DataType op converters (#17372) * cleanup `_mean()` * cleanup `_sum()` * cleanup `_argmax_argmin()` * cleanup datatype ops --- .../tvm/relax/frontend/torch/fx_translator.py | 123 ++++++++---------- 1 file changed, 55 insertions(+), 68 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 1c4796a533a4..4dc49d20ff36 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -884,6 +884,61 @@ def _unbind(self, node: fx.Node) -> relax.Var: ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) return self.block_builder.emit(relax.Tuple(ret)) + ########## Statistical ########## + + def _mean(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) + + def _sum(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + ########## Search ########## + + def _argmax_argmin(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node): + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(op(x, dim, keepdim)) + + return convert + + ########## DataType ########## + + def _float(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _to(self, node: fx.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x + + def _type(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1022,48 +1077,6 @@ def _full(self, node: fx.Node) -> relax.Var: ) ) - ########## Statistical ########## - - def _sum(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.sum(args[0], args[1])) - - def _mean(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim)) - - ########## DataType ########## - - def _float(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) - - def _half(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) - - def _type(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - - def _to(self, node: fx.Node) -> relax.Var: - import torch - - x = self.env[node.args[0]] - if len(node.args) == 2: - if isinstance(node.args[1], torch.dtype): - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - elif "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - return x - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: @@ -1220,32 +1233,6 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = output return output - ########## Search ########## - - def _argmax_argmin(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node): - x = self.env[node.args[0]] - dim = None - keepdims = False - - if len(node.args) > 1: - dim = node.args[1] - if len(node.args) > 2: - keepdims = node.args[2] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - if "keepdim" in node.kwargs: - keepdims = node.kwargs["keepdim"] - if "keepdims" in node.kwargs: - keepdims = node.kwargs["keepdims"] - - return self.block_builder.emit(op(x, dim, keepdims)) - - return convert - ########## Neural Network ########## def _softmax(self, node: fx.Node) -> relax.Var: From 11198f6e40a9999bb665d5bc1a7583471cbc0b06 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Sun, 15 Sep 2024 22:46:31 +0800 Subject: [PATCH 145/202] [MSC][Refactor] Support dynamic shape (#17351) * support prims for tir.Var * minor fix * bug fix for pruner --- .../tvm/contrib/msc/core/codegen/codegen.py | 7 +- .../contrib/msc/core/frontend/translate.py | 38 + python/tvm/contrib/msc/core/ir/graph.py | 93 +- .../contrib/msc/core/tools/prune/pruner.py | 7 +- python/tvm/contrib/msc/core/tools/tool.py | 3 + .../msc/framework/torch/frontend/translate.py | 4 +- python/tvm/contrib/msc/pipeline/pipeline.py | 12 +- python/tvm/contrib/msc/pipeline/utils.py | 37 +- python/tvm/contrib/msc/pipeline/wrapper.py | 3 + src/contrib/msc/core/codegen/base_codegen.h | 34 +- src/contrib/msc/core/codegen/codegen_utils.cc | 28 +- src/contrib/msc/core/codegen/codegen_utils.h | 33 +- src/contrib/msc/core/codegen/cpp_codegen.h | 14 + src/contrib/msc/core/codegen/py_codegen.h | 14 + src/contrib/msc/core/ir/graph.cc | 185 ++- src/contrib/msc/core/ir/graph.h | 156 +- src/contrib/msc/core/ir/graph_builder.cc | 151 +- src/contrib/msc/core/ir/graph_builder.h | 12 + .../msc/core/transform/layout_utils.cc | 51 +- src/contrib/msc/core/transform/layout_utils.h | 6 + .../msc/core/transform/set_expr_layout.cc | 440 +++--- .../msc/framework/tensorflow/codegen.cc | 3 +- src/contrib/msc/framework/tensorrt/codegen.cc | 3 +- src/contrib/msc/framework/torch/codegen.cc | 3 +- .../msc/framework/torch/torch_opcode.cc | 12 +- .../msc/framework/torch/torch_opcode.h | 6 +- src/contrib/msc/framework/tvm/codegen.cc | 13 +- src/contrib/msc/framework/tvm/codegen.h | 3 + src/contrib/msc/framework/tvm/relax_opcode.cc | 8 +- .../contrib/test_msc/test_graph_build.py | 1362 +++++++++++------ .../python/contrib/test_msc/test_pipeline.py | 6 +- tests/python/contrib/test_msc/test_runner.py | 30 +- tests/python/contrib/test_msc/test_tools.py | 4 +- 33 files changed, 1939 insertions(+), 842 deletions(-) diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py index c2711231f400..888f1bad4ebe 100644 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ b/python/tvm/contrib/msc/core/codegen/codegen.py @@ -180,9 +180,10 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None: def _to_var(tensor: MSCTensor): v_name = tensor.alias if use_alias else graph.find_producer(tensor).name - return tvm.relax.Var( - v_name, tvm.relax.TensorStructInfo(tensor.get_shape(), tensor.dtype_name) - ) + dims = [ + d if isinstance(d, int) else tvm.tir.Var(d, "int64") for d in tensor.get_shape(True) + ] + return tvm.relax.Var(v_name, tvm.relax.TensorStructInfo(dims, tensor.dtype_name)) def _save_weights(folder: msc_utils.MSCDirectory): if weights: diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py index cea021ade331..8e9bb0cf00d7 100644 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ b/python/tvm/contrib/msc/core/frontend/translate.py @@ -31,6 +31,44 @@ from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor +def normalize_inputs(inputs: List[tuple]) -> List[tuple]: + """Normalize the inputs info + + Parameters + ---------- + inputs: list of + The inputs info. + + Returns + ------- + inputs: list of + The normalized inputs info. + """ + + recorded_vars = {} + + def _normalize_input(inp): + def _normalize(info): + if not isinstance(info, (tuple, list)): + return info + dims = [] + for dim in info: + if isinstance(dim, int): + dims.append(dim) + elif dim in recorded_vars: + dims.append(recorded_vars[dim]) + elif isinstance(dim, str): + recorded_vars[dim] = tvm.tir.Var(dim, "int64") + dims.append(recorded_vars[dim]) + else: + raise TypeError("Unexpected dim {} in shape {}".format(dim, info)) + return dims + + return [_normalize(i) for i in inp] + + return [_normalize_input(inp) for inp in inputs] + + def normalize_weights( t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph ) -> Dict[str, tvm.nd.array]: diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 19a16a375b7a..172f40e06a31 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -41,6 +41,8 @@ class MSCTensor(Object): The shape of the tensor. alias: string The alias of the tensor. + prims: list + The prims of the tensor. """ def __init__( @@ -50,15 +52,31 @@ def __init__( layout: str, shape: List[int], alias: Optional[str] = None, + prims: List[str] = None, ): if not isinstance(dtype, tvm.DataType): dtype = tvm.DataType(dtype) self.__init_handle_by_constructor__( - _ffi_api.MSCTensor, name, dtype, layout, shape, alias or "" + _ffi_api.MSCTensor, name, dtype, layout, shape, alias or "", prims or [] ) - def get_shape(self) -> List[int]: - return [int(i) for i in self.shape] + def get_shape(self, with_prims: bool = False) -> List[Union[int, str]]: + """Get shape of the tensor + + Parameters + ------- + with_prims: bool + Whether get shape with prims. + + Returns + ------- + shape: list + The shape of tensor. + """ + + if not self.prims or not with_prims: + return [int(i) for i in self.shape] + return [int(p) if p.isdigit() else p for p in self.prims] def get_size(self) -> int: return int(_ffi_api.MSCTensorGetSize(self)) @@ -98,7 +116,7 @@ def equal(self, other: Object) -> bool: if not isinstance(other, MSCTensor): return False - if self.get_shape() != other.get_shape(): + if self.get_shape(True) != other.get_shape(True): return False if self.dtype != other.dtype: return False @@ -124,7 +142,7 @@ def inspect(self) -> dict: The tensor description in json format. """ - tensor_des = {"name": self.alias, "shape": self.get_shape(), "dtype": self.dtype_name} + tensor_des = {"name": self.alias, "shape": self.get_shape(True), "dtype": self.dtype_name} tensor_des["layout"] = self.layout.name if self.layout else "" return tensor_des @@ -405,6 +423,30 @@ def equal(self, other: BaseJoint) -> bool: return msc_utils.dict_equal(self.get_attrs(), other.get_attrs()) +@tvm._ffi.register_object("msc.core.MSCPrim") +class MSCPrim(BaseJoint): + """Prim in MSCGraph + + Parameters + ---------- + index: int + The index of the prim. + name: string + The name of the prim. + optype: string + The optype of the prim. + attrs: dict + The attributes of the node. + parents: list + The parents of the prim. + """ + + def __init__( + self, index: int, name: str, optype: str, attrs: Dict[str, str], parents: List[BaseJoint] + ): + self.__init_handle_by_constructor__(_ffi_api.MSCPrim, index, name, optype, attrs, parents) + + @tvm._ffi.register_object("msc.core.WeightJoint") class WeightJoint(BaseJoint): """Node in WeightGraph @@ -586,6 +628,22 @@ def find_node(self, name: str) -> MSCJoint: return _ffi_api.MSCGraphFindNode(self, name) + def find_prim(self, name: str) -> MSCPrim: + """Find prim by name. + + Parameters + ---------- + name: string + The name of the prim. + + Returns + ------- + prim: MSCPrim + The found prim. + """ + + return _ffi_api.MSCGraphFindPrim(self, name) + def has_tensor(self, name: str) -> bool: """Check if tensor in the graph. @@ -679,6 +737,18 @@ def get_nodes(self) -> Iterable[MSCJoint]: for n in self.node_names: yield self.find_node(n) + def get_prims(self) -> Iterable[MSCPrim]: + """Get all the prims in the graph. + + Returns + ------- + prims: generator + The generator of prims. + """ + + for n in self.prim_names: + yield self.find_prim(n) + def get_weights(self) -> Iterable[MSCTensor]: """Get all the weights in the graph. @@ -789,11 +859,16 @@ def inspect(self) -> dict: "nodes": {"total": 0}, } for node in self.get_nodes(): + graph_des["nodes"].setdefault(node.optype, 0) graph_des["nodes"]["total"] += 1 - if node.optype not in graph_des["nodes"]: - graph_des["nodes"][node.optype] = 1 - else: - graph_des["nodes"][node.optype] += 1 + graph_des["nodes"][node.optype] += 1 + prims = {"total": 0} + for prim in self.get_prims(): + prims.setdefault(prim.optype, 0) + prims["total"] += 1 + prims[prim.optype] += 1 + if prims["total"] > 0: + graph_des["prims"] = prims return graph_des @classmethod diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index 90273e25416b..a008100be252 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -340,7 +340,12 @@ def _prune_by_shape(tensor: MSCTensor, shape: List[int]): def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): shape = tensor.get_shape() if channel_axis is None: - channel_axis = tensor.layout_of("C") + if self.has_w_node(tensor.name): + w_node = self.find_w_node(tensor.name) + _, channel_axis = self._get_io_axes(w_node) + else: + channel_axis = tensor.layout_of("C") + assert channel_axis >= 0, "Can not infer channel_axis for " + str(tensor) shape[channel_axis] = dim return _prune_by_shape(tensor, shape) diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index 626ae312bcf4..06a16f2bbe49 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -1620,6 +1620,9 @@ def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]: in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O") if in_axis >= 0 and out_axis >= 0: return in_axis, out_axis + if w_node.weight.ndim == 2 and w_node.weight.dim_at("N") > 0: + io_axis = 1 - w_node.weight.layout_of("N") + return io_axis, io_axis if w_node.weight.layout_of("C") >= 0: return w_node.weight.layout_of("C"), w_node.weight.layout_of("C") raise Exception("Can not infer in_axis/out_axis from " + str(w_node)) diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py index 2509f1abfcbe..c8c2844c2859 100644 --- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py @@ -22,9 +22,8 @@ import torch import tvm from tvm.relax.frontend.torch import from_fx - from tvm.contrib.msc.core.ir.graph import MSCGraph -from tvm.contrib.msc.core.frontend import from_relax +from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs from tvm.contrib.msc.core.codegen import relay_to_relax @@ -104,6 +103,7 @@ def from_torch( """ if via_relax: + input_info = normalize_inputs(input_info) graph_model, params = torch.fx.symbolic_trace(model), None with torch.no_grad(): relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map) diff --git a/python/tvm/contrib/msc/pipeline/pipeline.py b/python/tvm/contrib/msc/pipeline/pipeline.py index f02503a113ca..e003f692241c 100644 --- a/python/tvm/contrib/msc/pipeline/pipeline.py +++ b/python/tvm/contrib/msc/pipeline/pipeline.py @@ -676,10 +676,20 @@ def _get_loader(self, name: str = MSCStage.PREPARE) -> Any: max_batch = config.get("max_batch", 5) def get_random(): + def _to_data(inp): + shape = [1 if isinstance(d, str) else d for d in inp[1]] + return np.random.rand(*shape).astype(inp[2]) + for _ in range(max_batch): - yield {i[0]: np.random.rand(*i[1]).astype(i[2]) for i in self._config["inputs"]} + yield {i[0]: _to_data(i) for i in self._config["inputs"]} loader, source_type = get_random, "random" + elif isinstance(source_loader, dict): + + def load_data(): + return [source_loader] + + loader, source_type = load_data, "dict" elif msc_utils.is_io_dataset(source_loader): max_batch = config.get("max_batch", -1) diff --git a/python/tvm/contrib/msc/pipeline/utils.py b/python/tvm/contrib/msc/pipeline/utils.py index e4d91ee14b62..c6689e1f0091 100644 --- a/python/tvm/contrib/msc/pipeline/utils.py +++ b/python/tvm/contrib/msc/pipeline/utils.py @@ -16,6 +16,7 @@ # under the License. """tvm.contrib.msc.pipeline.config""" +import copy from typing import List, Union, Dict, Tuple from tvm.contrib.msc.core.tools import ToolType @@ -129,6 +130,7 @@ def create_config( dataset: Dict[str, dict] = None, tools: List[Tuple[str, Union[dict, str]]] = None, dynamic: bool = False, + run_config: Dict[str, dict] = None, skip_config: Dict[str, str] = None, **extra_config, ) -> dict: @@ -160,11 +162,13 @@ def create_config( The extra config. """ + all_stages = [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE] baseline_type = baseline_type or model_type optimize_type = optimize_type or baseline_type compile_type = compile_type or optimize_type tools = tools or [] tools = [config_tool(t_type, t_config) for t_type, t_config in tools] + extra_config = extra_config or {} # basic config config = { "model_type": model_type, @@ -194,27 +198,34 @@ def create_config( "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, } + # update run config + if run_config: + if "all" in run_config: + all_config = run_config.pop("all") + run_config.update({s: copy.deepcopy(all_config) for s in all_stages}) + for stage, r_config in run_config.items(): + extra_config.setdefault(stage, {}).setdefault("run_config", {}).update(r_config) + # update config if extra_config: config = msc_utils.update_dict(config, extra_config) # skip stages - skip_config = skip_config or {} - for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: - if stage not in config: - continue - for key in ["all", stage]: - if key not in skip_config: + if skip_config: + if "all" in run_config: + all_config = skip_config.pop("all") + skip_config.update({s: copy.deepcopy(all_config) for s in all_stages}) + for stage, s_type in skip_config.items(): + if stage not in config: continue - if skip_config[key] == "stage": + if s_type == "stage": config.pop(stage) - elif skip_config[key] == "profile": + elif s_type == "profile": config[stage].pop("profile") - elif skip_config[key] == "check": - config[stage]["profile"].pop("check") - elif skip_config[key] == "benchmark": + elif s_type == "check": + config[stage]["profile"]["check"]["err_rate"] = -1 + elif s_type == "benchmark": config[stage]["profile"].pop("benchmark") else: - raise TypeError("Unexpected skip type " + str(skip_config[key])) - + raise TypeError("Unexpected skip type " + str(s_type)) return config diff --git a/python/tvm/contrib/msc/pipeline/wrapper.py b/python/tvm/contrib/msc/pipeline/wrapper.py index 1332b3c79115..91862c794027 100644 --- a/python/tvm/contrib/msc/pipeline/wrapper.py +++ b/python/tvm/contrib/msc/pipeline/wrapper.py @@ -240,6 +240,9 @@ class TorchWrapper(BaseWrapper): """Wrapper of torch models""" def __call__(self, *inputs): + return self.forward(*inputs) + + def forward(self, *inputs): framework = self._get_framework() if framework != MSCFramework.TORCH: inputs = [msc_utils.cast_array(i, framework, self.device) for i in inputs] diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h index acaac896a153..f582f6416d93 100644 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ b/src/contrib/msc/core/codegen/base_codegen.h @@ -58,9 +58,11 @@ class BaseOpCode { virtual ~BaseOpCode() = default; /*! \brief Config the BaseOpCode*/ - void Config(const MSCJoint& node, const std::shared_ptr config) { + void Config(const MSCJoint& node, const std::shared_ptr config, + const Map& prims) { node_ = node; config_ = config; + prims_ = prims; } /*! \brief Get docs for the node*/ @@ -158,6 +160,13 @@ class BaseCodeGen { } } + virtual void Init() { + // define prims + for (const auto& p_name : this->graph()->prim_names) { + prims_.Set(p_name, this->DescribePrim(this->graph()->FindPrim(p_name))); + } + } + virtual ~BaseCodeGen() = default; /*! \brief Get sources*/ @@ -211,6 +220,29 @@ class BaseCodeGen { /*! \brief Get the docs for the op*/ virtual const Array GetOpCodes(const MSCJoint& node) = 0; + /*! \brief Describe the prim*/ + virtual const String DescribePrim(const MSCPrim& prim) { + if (prim->optype == "Int") { + return prim->GetTypeAttr("value"); + } + if (prim->optype == "shape") { + const auto& producer = this->graph()->FindNode(prim->GetTypeAttr("producer")); + int out_idx = prim->GetTypeAttr("out_idx"); + const auto& dim = prim->GetTypeAttr("dim"); + return this->IdxOutputBase(producer, out_idx) + ".shape[" + dim + "]"; + } + // binary ops + DESCRIBE_PRIM_BINARY("Add", "+", false) + DESCRIBE_PRIM_BINARY("Sub", "-", false) + DESCRIBE_PRIM_BINARY("Mul", "*", false) + DESCRIBE_PRIM_BINARY("Divide", "/", false) + DESCRIBE_PRIM_BINARY("LT", "<", false) + DESCRIBE_PRIM_BINARY("LE", "<=", false) + DESCRIBE_PRIM_BINARY("GT", ">", false) + DESCRIBE_PRIM_BINARY("GE", ">=", false) + LOG_FATAL << "Unexpected prim " << prim; + } + /*! \brief Get the graph*/ const MSCGraph graph() const { return graph_; } diff --git a/src/contrib/msc/core/codegen/codegen_utils.cc b/src/contrib/msc/core/codegen/codegen_utils.cc index 44626debe1d8..741b729bd015 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.cc +++ b/src/contrib/msc/core/codegen/codegen_utils.cc @@ -54,13 +54,37 @@ const String CodeGenUtils::IdxWeight(const MSCJoint& node, const String& wtype, return wtype + "_" + std::to_string(node->index) + suffix; } -const String CodeGenUtils::CommentNode(const MSCJoint& node, const String& prefix) { +const Array CodeGenUtils::GetPrims(const MSCTensor& tensor, + const Map& prims) { + Array dims; + if (tensor->prims.size() == 0) { + for (size_t i = 0; i < tensor->Ndim(); i++) { + dims.push_back(StringUtils::ToString(tensor->DimAt(i))); + } + return dims; + } + for (size_t i = 0; i < tensor->Ndim(); i++) { + const auto& prim = tensor->PrimAt(i); + dims.push_back(prims.count(prim) ? prims[prim] : prim); + } + return dims; +} + +const String CodeGenUtils::CommentNode(const MSCJoint& node, const String& prefix, + const Map& prims) { String comment = node->name + "(" + node->optype + "): <"; for (size_t i = 0; i < node->inputs.size(); i++) { comment = comment + IdxInput(node, prefix, i) + (i == node->inputs.size() - 1 ? "> -> <" : ","); } for (size_t i = 0; i < node->outputs.size(); i++) { - comment = comment + IdxOutput(node, prefix, i) + (i == node->outputs.size() - 1 ? ">" : ","); + const auto& t_output = node->OutputAt(i); + const auto& t_prims = GetPrims(t_output, prims); + comment = comment + IdxOutput(node, prefix, i) + "|" + StringUtils::Join(t_prims, ":"); + comment = comment + "|" + t_output->DTypeName(); + if (t_output->layout.defined()) { + comment = comment + "|" + t_output->layout->name; + } + comment = comment + (i == node->outputs.size() - 1 ? ">" : ", "); } return comment; } diff --git a/src/contrib/msc/core/codegen/codegen_utils.h b/src/contrib/msc/core/codegen/codegen_utils.h index 1af8df5ac1a4..abdb91b4703f 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.h +++ b/src/contrib/msc/core/codegen/codegen_utils.h @@ -76,12 +76,23 @@ using namespace tvm::script::printer; LOG(FATAL) << "Do not support key " << key; \ } +#define DESCRIBE_PRIM_BINARY(OpType, Symbol, AsFunc) \ + if (prim->optype == OpType) { \ + if (AsFunc) { \ + return std::string(Symbol) + "(" + this->DescribePrim(prim->ParentAt(0)) + "," + \ + this->DescribePrim(prim->ParentAt(1)) + ")"; \ + } \ + return "(" + this->DescribePrim(prim->ParentAt(0)) + Symbol + \ + this->DescribePrim(prim->ParentAt(1)) + ")"; \ + } + #define CODEGEN_MEMBERS \ public: \ virtual const String DType(const DataType& dtype) { return runtime::DLDataType2String(dtype); } \ \ protected: \ const std::shared_ptr config() { return config_; } \ + const Map prims() { return prims_; } \ const String IdxNodeBase(const MSCJoint& node) { \ return helper_.IdxNodeBase(node, config()->prefix, ""); \ } \ @@ -95,13 +106,19 @@ using namespace tvm::script::printer; const String IdxWeightBase(const MSCJoint& node, const String& wtype, bool process = true) { \ return helper_.IdxWeightBase(node, wtype, "", process && config()->use_tools); \ } \ - const String Comment(const MSCJoint& node) { return helper_.Comment(node, config()->prefix); } \ + const Array GetPrims(const MSCTensor& tensor) { \ + return CodeGenUtils::GetPrims(tensor, prims_); \ + } \ + const String Comment(const MSCJoint& node) { \ + return helper_.Comment(node, config()->prefix, prims_); \ + } \ int CompareVersion(size_t major, size_t minor, size_t patch) { \ return CommonUtils::CompareVersion(config()->version, {major, minor, patch}); \ } \ \ private: \ std::shared_ptr config_; \ + Map prims_; \ HelperType helper_; /*! @@ -137,11 +154,18 @@ class CodeGenUtils { TVM_DLL static const String IdxWeight(const MSCJoint& node, const String& wtype, const String& suffix = ""); + /*! + * \brief Infer prims of tensor. + * \return The prims. + */ + TVM_DLL static const Array GetPrims(const MSCTensor& tensor, + const Map& prims); /*! * \brief Get comment of a node. * \return The String. */ - TVM_DLL static const String CommentNode(const MSCJoint& node, const String& prefix); + TVM_DLL static const String CommentNode(const MSCJoint& node, const String& prefix, + const Map& prims); }; /*! @@ -180,8 +204,9 @@ class BaseCodeGenHelper { const String& suffix = "", bool process = false) { return CodeGenUtils::IdxWeight(node, wtype, suffix + GetSuffix(node, process)); } - virtual const String Comment(const MSCJoint& node, const String& prefix = "") { - return CodeGenUtils::CommentNode(node, prefix); + virtual const String Comment(const MSCJoint& node, const String& prefix = "", + const Map& prims = Map()) { + return CodeGenUtils::CommentNode(node, prefix, prims); } }; diff --git a/src/contrib/msc/core/codegen/cpp_codegen.h b/src/contrib/msc/core/codegen/cpp_codegen.h index 2c07aeb4c741..81b7d1e871a2 100644 --- a/src/contrib/msc/core/codegen/cpp_codegen.h +++ b/src/contrib/msc/core/codegen/cpp_codegen.h @@ -95,6 +95,20 @@ class CppCodeGen : public BaseCodeGen { } protected: + /*! \brief Describe the prim*/ + virtual const String DescribePrim(const MSCPrim& prim) { + // binary ops + DESCRIBE_PRIM_BINARY("Min", "std::min", true) + DESCRIBE_PRIM_BINARY("Max", "std::max", true) + // special + if (prim->optype == "if_then_else") { + return "(" + this->DescribePrim(prim->ParentAt(0)) + "?" + + this->DescribePrim(prim->ParentAt(1)) + ":" + this->DescribePrim(prim->ParentAt(2)) + + ")"; + } + return BaseCodeGen::DescribePrim(prim); + } + /*! \brief Stack the docs for the node*/ virtual void CodeGenNode(const MSCJoint& node, bool use_tools) { this->stack_.comment(this->Comment(node)); diff --git a/src/contrib/msc/core/codegen/py_codegen.h b/src/contrib/msc/core/codegen/py_codegen.h index e1ceb716a278..c1ecded61df1 100644 --- a/src/contrib/msc/core/codegen/py_codegen.h +++ b/src/contrib/msc/core/codegen/py_codegen.h @@ -82,6 +82,20 @@ class PyCodeGen : public BaseCodeGen { } protected: + /*! \brief Describe the prim*/ + virtual const String DescribePrim(const MSCPrim& prim) { + // binary ops + DESCRIBE_PRIM_BINARY("Min", "min", true) + DESCRIBE_PRIM_BINARY("Max", "max", true) + // special + if (prim->optype == "if_then_else") { + return "(" + this->DescribePrim(prim->ParentAt(1)) + " if " + + this->DescribePrim(prim->ParentAt(0)) + " else " + + this->DescribePrim(prim->ParentAt(2)) + ")"; + } + return BaseCodeGen::DescribePrim(prim); + } + /*! \brief Stack the docs for the header*/ virtual void CodeGenHeader() { this->stack_.line("import os") diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index ca1bff09725f..ae42537a4ce1 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -35,13 +35,14 @@ namespace contrib { namespace msc { MSCTensor::MSCTensor(const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias) { + const Array& shape, const String& alias, const Array& prims) { ObjectPtr n = make_object(); n->name = std::move(name); n->alias = std::move(alias); n->dtype = std::move(dtype); n->shape = std::move(shape); n->layout = tvm::tir::Layout(layout); + n->prims = prims; data_ = std::move(n); } @@ -68,6 +69,9 @@ const JsonMSCTensor MSCTensorNode::ToJson() const { for (const auto& s : shape) { j_tensor.shape.push_back(s->value); } + for (const auto& p : prims) { + j_tensor.prims.push_back(p); + } return j_tensor; } @@ -81,6 +85,9 @@ void MSCTensorNode::FromJson(const JsonMSCTensor& j_tensor) { for (const auto& s : j_tensor.shape) { shape.push_back(s); } + for (const auto& p : j_tensor.prims) { + prims.push_back(p); + } } void MSCTensorNode::FromJson(const std::string& json_str) { @@ -103,6 +110,17 @@ const Integer MSCTensorNode::DimAt(const String& axis) const { return DimAt(index); } +const String MSCTensorNode::PrimAt(int index) const { + if (prims.size() == 0) { + return ""; + } + return prims[CommonUtils::GetIndex(index, Ndim())]; +} + +const String MSCTensorNode::PrimAt(const String& axis) const { + return PrimAt(layout.IndexOf(tvm::tir::LayoutAxis::Get(axis))); +} + int32_t MSCTensorNode::LayoutOf(const String& axis) const { return layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); } @@ -498,6 +516,76 @@ const std::pair MSCJointNode::ProducerAndIdxOf(const MSCTensor return ProducerAndIdxOf(input->name); } +MSCPrim::MSCPrim(int index, const String& name, const String& optype, + const Array& parents, const Map& attrs) { + ObjectPtr n = make_object(); + n->index = index; + n->name = std::move(name); + n->optype = std::move(optype); + n->attrs = std::move(attrs); + for (const auto& p : parents) { + n->parents.push_back(p); + } + data_ = std::move(n); +} + +MSCPrim::MSCPrim(const JsonMSCPrim& j_prim, const Map& prims) { + ObjectPtr n = make_object(); + n->FromJson(j_prim, prims); + data_ = std::move(n); +} + +MSCPrim::MSCPrim(const std::string& json_str, const Map& prims) { + ObjectPtr n = make_object(); + n->FromJson(json_str, prims); + data_ = std::move(n); +} + +const JsonMSCPrim MSCPrimNode::ToJson() const { + JsonMSCPrim j_prim; + j_prim.index = index; + j_prim.name = name; + j_prim.optype = optype; + for (const auto& pair : attrs) { + j_prim.attrs[pair.first] = pair.second; + } + for (const auto& p : parents) { + j_prim.parents.push_back(Downcast(p)->name); + } + return j_prim; +} + +void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, const Map& prims) { + index = j_prim.index; + name = j_prim.name; + optype = j_prim.optype; + for (const auto& pair : j_prim.attrs) { + attrs.Set(pair.first, pair.second); + } + for (const auto& p_name : j_prim.parents) { + ICHECK(prims.count(p_name)) << "Can not find parent " << p_name; + parents.push_back(prims[p_name]); + } +} + +void MSCPrimNode::FromJson(const std::string& json_str, const Map& prims) { + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JsonMSCPrim j_prim; + reader.Read(&j_prim); + FromJson(j_prim, prims); +} + +const MSCPrim MSCPrimNode::ParentAt(int index) const { + size_t v_index = CommonUtils::GetIndex(index, parents.size()); + return Downcast(parents[v_index]); +} + +const MSCPrim MSCPrimNode::ChildAt(int index) const { + size_t v_index = CommonUtils::GetIndex(index, children.size()); + return Downcast(children[v_index]); +} + WeightJoint::WeightJoint(int index, const String& name, const String& shared_ref, const String& weight_type, const MSCTensor& weight, const Array parents, const Map& attrs, @@ -587,7 +675,8 @@ const bool BaseGraphNode::HasNode(const String& name) const { } MSCGraph::MSCGraph(const String& name, const Array& nodes, - const Array& input_names, const Array& output_names) { + const Array& input_names, const Array& output_names, + const Array& prims) { ObjectPtr n = make_object(); n->name = std::move(name); for (const auto& node : nodes) { @@ -596,6 +685,10 @@ MSCGraph::MSCGraph(const String& name, const Array& nodes, } n->input_names = std::move(input_names); n->output_names = std::move(output_names); + for (const auto& prim : prims) { + n->prim_names.push_back(prim->name); + n->prims.Set(prim->name, prim); + } n->AnalysisGraph(); data_ = std::move(n); } @@ -625,6 +718,10 @@ const JsonMSCGraph MSCGraphNode::ToJson() const { const auto& node = FindNode(n); j_graph.nodes.push_back(node->ToJson()); } + for (const auto& n : prim_names) { + const auto& prim = FindPrim(n); + j_graph.prims.push_back(prim->ToJson()); + } return j_graph; } @@ -646,6 +743,16 @@ void MSCGraphNode::FromJson(const JsonMSCGraph& j_graph) { node_names.push_back(node->name); nodes.Set(node->name, node); } + Map loaded_prims; + for (const auto& n : j_graph.prims) { + const auto& prim = MSCPrim(n, loaded_prims); + loaded_prims.Set(prim->name, prim); + for (const auto& p : prim->parents) { + Downcast(p)->AddChild(prim); + } + prim_names.push_back(prim->name); + prims.Set(prim->name, prim); + } AnalysisGraph(); } @@ -697,6 +804,11 @@ const MSCJoint MSCGraphNode::FindNode(const String& name) const { return Downcast(nodes[name]); } +const MSCPrim MSCGraphNode::FindPrim(const String& name) const { + ICHECK(prims.count(name)) << "Can not find prim " << name; + return prims[name]; +} + const MSCTensor MSCGraphNode::InputAt(int index) const { size_t v_index = CommonUtils::GetIndex(index, input_names.size()); return FindTensor(input_names[v_index]); @@ -1004,9 +1116,8 @@ void WeightGraphNode::Build(const MSCGraph& graph, const MapOutputAt(0); Map attrs; attrs.Set("producer_type", node->optype); - if (node->optype == "reshape" && node->InputAt(0)->LayoutOf("C") >= 0 && - node->OutputAt(0)->LayoutOf("C") >= 0 && - node->InputAt(0)->DimAt("C")->value == node->OutputAt(0)->DimAt("C")->value) { + if (node->optype == "reshape") { + // TODO(archermmt): check non-passby reshape attrs.Set("weight_strategy", "passby"); } else { attrs.Set("weight_strategy", relation_wtypes[node->optype]); @@ -1155,7 +1266,11 @@ MSCGraph PruneWeights(const MSCGraph& graph, const Map& prune Downcast(p)->AddChild(new_node); } } - return MSCGraph(graph->name, nodes, graph->input_names, graph->output_names); + Array prims; + for (const auto& name : graph->prim_names) { + prims.push_back(graph->FindPrim(name)); + } + return MSCGraph(graph->name, nodes, graph->input_names, graph->output_names, prims); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -1168,7 +1283,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } p->stream << "<"; for (size_t i = 0; i < tensor->Ndim(); i++) { - p->stream << tensor->shape[i]->value << (i == tensor->Ndim() - 1 ? "|" : ","); + const auto& prim = tensor->PrimAt(i); + p->stream << (prim.size() > 0 ? prim : StringUtils::ToString(tensor->shape[i])) + << (i == tensor->Ndim() - 1 ? "|" : ","); } p->stream << tensor->dtype; if (tensor->layout.defined()) { @@ -1177,8 +1294,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ">"; }); -#define MSC_NODE_BASE_HEAD(Stream, Joint) \ - Stream << "ID_" << Joint->index << " " << Joint->name; \ +#define MSC_NODE_BASE_HEAD(Stream, Joint, Type) \ + Stream << Type << "_" << Joint->index << " " << Joint->name; \ if (Joint->shared_ref.size() > 0) { \ Stream << "(M: " << Joint->shared_ref << ")"; \ } \ @@ -1200,7 +1317,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* joint = static_cast(node.get()); p->PrintIndent(); - MSC_NODE_BASE_HEAD(p->stream, joint); + MSC_NODE_BASE_HEAD(p->stream, joint, "N"); if (joint->inputs.size() > 0) { p->stream << " IN: "; for (size_t i = 0; i < joint->inputs.size(); i++) { @@ -1234,11 +1351,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* prim = static_cast(node.get()); + p->PrintIndent(); + MSC_NODE_BASE_HEAD(p->stream, prim, "P"); + p->stream << " OPTYPE: " << prim->optype; + if (prim->attrs.size() > 0) { + p->stream << "\n ATTRS: "; + for (const auto& pair : prim->attrs) { + p->stream << pair.first << "=" << pair.second << " "; + } + } + p->stream << "\n"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* joint = static_cast(node.get()); p->PrintIndent(); - MSC_NODE_BASE_HEAD(p->stream, joint); + MSC_NODE_BASE_HEAD(p->stream, joint, "W"); if (joint->friends.size() > 0) { p->stream << " FRIENDS: "; for (size_t i = 0; i < joint->friends.size(); i++) { @@ -1279,6 +1411,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) for (size_t i = 0; i < graph->output_names.size(); i++) { p->stream << graph->output_names[i] << (i == graph->output_names.size() - 1 ? ">\n" : ","); } + for (const auto& n : graph->prim_names) { + p->stream << graph->FindPrim(n) << "\n"; + } for (const auto& n : graph->node_names) { p->stream << graph->FindNode(n) << "\n"; } @@ -1288,6 +1423,8 @@ TVM_REGISTER_NODE_TYPE(MSCTensorNode); TVM_REGISTER_NODE_TYPE(MSCJointNode); +TVM_REGISTER_NODE_TYPE(MSCPrimNode); + TVM_REGISTER_NODE_TYPE(WeightJointNode); TVM_REGISTER_NODE_TYPE(MSCGraphNode); @@ -1296,8 +1433,9 @@ TVM_REGISTER_NODE_TYPE(WeightGraphNode); TVM_REGISTER_GLOBAL("msc.core.MSCTensor") .set_body_typed([](const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias) -> MSCTensor { - return MSCTensor(name, dtype, layout, shape, alias); + const Array& shape, const String& alias, + const Array& prims) -> MSCTensor { + return MSCTensor(name, dtype, layout, shape, alias, prims); }); TVM_REGISTER_GLOBAL("msc.core.MSCTensorToJson") @@ -1326,6 +1464,16 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJoint") weights); }); +TVM_REGISTER_GLOBAL("msc.core.MSCPrim") + .set_body_typed([](Integer index, const String& name, const String& optype, + const Map& attrs, const Array& parents) -> MSCPrim { + Array b_parents; + for (const auto& p : parents) { + b_parents.push_back(p); + } + return MSCPrim(index->value, name, optype, b_parents, attrs); + }); + TVM_REGISTER_GLOBAL("msc.core.WeightJoint") .set_body_typed([](Integer index, const String& name, const String& shared_ref, const String& weight_type, const MSCTensor& weight, @@ -1349,9 +1497,9 @@ TVM_REGISTER_GLOBAL("msc.core.WeightJointSetAttr") TVM_REGISTER_GLOBAL("msc.core.MSCGraph") .set_body_typed([](const String& name, const Array& nodes, - const Array& input_names, - const Array& output_names) -> MSCGraph { - return MSCGraph(name, nodes, input_names, output_names); + const Array& input_names, const Array& output_names, + const Array& prims) -> MSCGraph { + return MSCGraph(name, nodes, input_names, output_names, prims); }); TVM_REGISTER_GLOBAL("msc.core.WeightGraph") @@ -1371,6 +1519,11 @@ TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindNode") return graph->FindNode(name); }); +TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindPrim") + .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCPrim { + return graph->FindPrim(name); + }); + TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasTensor") .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { return Bool(graph->HasTensor(name)); diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index 7005518f367b..1e22e96ac951 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -48,6 +48,7 @@ struct JsonMSCTensor { std::string dtype; std::string layout; std::vector shape; + std::vector prims; void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); @@ -56,6 +57,7 @@ struct JsonMSCTensor { writer->WriteObjectKeyValue("dtype", dtype); writer->WriteObjectKeyValue("layout", layout); writer->WriteObjectKeyValue("shape", shape); + writer->WriteObjectKeyValue("prims", prims); writer->EndObject(); } @@ -77,6 +79,8 @@ struct JsonMSCTensor { } else if (key == "shape") { reader->Read(&shape); bitmask |= 4; + } else if (key == "prims") { + reader->Read(&prims); } } ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, dtype and shape should be given"; @@ -147,6 +151,51 @@ struct JsonMSCJoint { } }; +/*! + * \brief Json serialize and deserialize for MSCPrim. + * MSCPrim is node in MSCGraph with name, op and attrbutes. + */ +struct JsonMSCPrim { + size_t index; + std::string name; + std::string optype; + std::vector parents; + std::unordered_map attrs; + + void Save(dmlc::JSONWriter* writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("index", index); + writer->WriteObjectKeyValue("name", name); + writer->WriteObjectKeyValue("optype", optype); + writer->WriteObjectKeyValue("parents", parents); + writer->WriteObjectKeyValue("attrs", attrs); + writer->EndObject(); + } + + void Load(dmlc::JSONReader* reader) { + int bitmask = 0; + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "index") { + reader->Read(&index); + bitmask |= 1; + } else if (key == "name") { + reader->Read(&name); + bitmask |= 2; + } else if (key == "optype") { + reader->Read(&optype); + bitmask |= 4; + } else if (key == "parents") { + reader->Read(&parents); + } else if (key == "attrs") { + reader->Read(&attrs); + } + } + ICHECK_EQ(bitmask, 1 | 2 | 4) << "index, name and optype should be given"; + } +}; + /*! * \brief Json serialize and deserialize for WeightJoint. * WeightJoint is node in WeightGraph with name, wtype and attrbutes. @@ -216,6 +265,7 @@ struct JsonMSCGraph { std::vector inputs; std::vector outputs; std::vector nodes; + std::vector prims; void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); @@ -223,6 +273,7 @@ struct JsonMSCGraph { writer->WriteObjectKeyValue("inputs", inputs); writer->WriteObjectKeyValue("outputs", outputs); writer->WriteObjectKeyValue("nodes", nodes); + writer->WriteObjectKeyValue("prims", prims); writer->EndObject(); } @@ -243,6 +294,8 @@ struct JsonMSCGraph { } else if (key == "nodes") { reader->Read(&nodes); bitmask |= 8; + } else if (key == "prims") { + reader->Read(&prims); } } ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "name, inputs, outputs and nodes should be given"; @@ -297,6 +350,8 @@ class MSCTensorNode : public Object { tvm::tir::Layout layout; /*! \brief The shape of tensor. */ Array shape; + /*! \brief The prims of tensor. */ + Array prims; /*! \brief Export tensor to json. */ const JsonMSCTensor ToJson() const; /*! \brief Load tensor from json struct. */ @@ -309,6 +364,10 @@ class MSCTensorNode : public Object { const Integer DimAt(int index) const; /*! \brief Get dim at given axis. */ const Integer DimAt(const String& axis) const; + /*! \brief Get prim at given index. */ + const String PrimAt(int index) const; + /*! \brief Get prim at given axis. */ + const String PrimAt(const String& axis) const; /*! \brief Get layout index of given axis. */ int32_t LayoutOf(const String& axis) const; /*! \brief Get size of the tensor. */ @@ -322,11 +381,12 @@ class MSCTensorNode : public Object { v->Visit("dtype", &dtype); v->Visit("layout", &layout); v->Visit("shape", &shape); + v->Visit("prims", &prims); } bool SEqualReduce(const MSCTensorNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(dtype, other->dtype) && equal(shape, other->shape) && - equal(layout, other->layout); + equal(layout, other->layout) && equal(prims, other->prims); } void SHashReduce(SHashReducer hash_reduce) const { @@ -334,6 +394,7 @@ class MSCTensorNode : public Object { hash_reduce(dtype); hash_reduce(shape); hash_reduce(layout); + hash_reduce(prims); } static constexpr const char* _type_key = "msc.core.MSCTensor"; @@ -353,9 +414,11 @@ class MSCTensor : public ObjectRef { * \param layout The layout of the tensor. * \param shape The shape of the tensor. * \param alias The alias of the tensor. + * \param prims The prims of the tensor shape. */ TVM_DLL MSCTensor(const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias = ""); + const Array& shape, const String& alias = "", + const Array& prims = Array()); /*! * \brief The json constructor. @@ -576,6 +639,76 @@ class MSCJoint : public BaseJoint { TVM_DEFINE_OBJECT_REF_METHODS(MSCJoint, BaseJoint, MSCJointNode); }; +/*! + * \brief MSCPrim in MSCGraph. + */ +class MSCPrim; +class MSCPrimNode : public BaseJointNode { + public: + /*! \brief The op of prim. */ + String optype; + /*! \brief Export prim to json. */ + const JsonMSCPrim ToJson() const; + /*! \brief Load prim from json struct. */ + void FromJson(const JsonMSCPrim& j_prim, const Map& prims); + /*! \brief Load prim from json string. */ + void FromJson(const std::string& json_str, const Map& prims); + /*! \brief Get parent from the prim. */ + const MSCPrim ParentAt(int index) const; + /*! \brief Get child from the prim. */ + const MSCPrim ChildAt(int index) const; + + void VisitAttrs(AttrVisitor* v) { + BaseJointNode::VisitAttrs(v); + v->Visit("optype", &optype); + } + + bool SEqualReduce(const MSCPrimNode* other, SEqualReducer equal) const { + return BaseJointNode::SEqualReduce(other, equal) && equal(optype, other->optype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + BaseJointNode::SHashReduce(hash_reduce); + hash_reduce(optype); + } + + static constexpr const char* _type_key = "msc.core.MSCPrim"; + TVM_DECLARE_FINAL_OBJECT_INFO(MSCPrimNode, BaseJointNode); +}; + +/*! + * \brief Managed reference to MSCPrimNode. + * \sa MSCPrimNode + */ +class MSCPrim : public BaseJoint { + public: + /*! + * \brief The constructor. + * \param index The index of the prim. + * \param name The name of the prim. + * \param optype The optype of the prim. + * \param parents The parents of the prim. + * \param attrs The attributes of the prim. + */ + TVM_DLL MSCPrim(int index, const String& name, const String& optype, + const Array& parents, + const Map& attrs = Map()); + + /*! + * \brief The json constructor. + * \param j_prim The json describe of the prim. + */ + TVM_DLL MSCPrim(const JsonMSCPrim& j_prim, const Map& prims); + + /*! + * \brief The json constructor. + * \param json_str The json describe of the prim. + */ + TVM_DLL MSCPrim(const std::string& json_str, const Map& prims); + + TVM_DEFINE_OBJECT_REF_METHODS(MSCPrim, BaseJoint, MSCPrimNode); +}; + /*! * \brief Node in WeightGraph. */ @@ -713,6 +846,10 @@ class BaseGraph : public ObjectRef { class MSCGraph; class MSCGraphNode : public BaseGraphNode { public: + /*! \brief The shape node names in graph. */ + Array prim_names; + /*! \brief The shape nodes in graph. */ + Map prims; /*! \brief The input names of graph. */ Array input_names; /*! \brief The output names of graph. */ @@ -731,6 +868,8 @@ class MSCGraphNode : public BaseGraphNode { const String ToPrototxt() const; /*! \brief Find node in graph. */ const MSCJoint FindNode(const String& name) const; + /*! \brief Find prim in graph. */ + const MSCPrim FindPrim(const String& name) const; /*! \brief Get input from the graph. */ const MSCTensor InputAt(int index) const; /*! \brief Get inputs from the graph. */ @@ -769,18 +908,23 @@ class MSCGraphNode : public BaseGraphNode { void VisitAttrs(AttrVisitor* v) { BaseGraphNode::VisitAttrs(v); + v->Visit("prims", &prims); + v->Visit("prim_names", &prim_names); v->Visit("input_names", &input_names); v->Visit("output_names", &output_names); v->Visit("weight_holders", &weight_holders); } bool SEqualReduce(const MSCGraphNode* other, SEqualReducer equal) const { - return BaseGraphNode::SEqualReduce(other, equal) && equal(input_names, other->input_names) && + return BaseGraphNode::SEqualReduce(other, equal) && equal(prims, other->prims) && + equal(prim_names, other->prim_names) && equal(input_names, other->input_names) && equal(output_names, other->output_names) && equal(weight_holders, other->weight_holders); } void SHashReduce(SHashReducer hash_reduce) const { BaseGraphNode::SHashReduce(hash_reduce); + hash_reduce(prims); + hash_reduce(prim_names); hash_reduce(input_names); hash_reduce(output_names); hash_reduce(weight_holders); @@ -799,14 +943,14 @@ class MSCGraph : public BaseGraph { /*! * \brief The constructor. * \param name The name of the node. - * \param node_names The node names in the graph * \param nodes The nodes in the graph. * \param input_names The input names of the graph. * \param output_names The output names of the graph. - * \param weight_holders The weights info of the graph. + * \param prims The prims in the graph. */ TVM_DLL MSCGraph(const String& name, const Array& nodes, - const Array& input_names, const Array& output_names); + const Array& input_names, const Array& output_names, + const Array& prims = Array()); /*! * \brief The json constructor. diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index a968df4204a2..20c7dbcc9172 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -138,6 +138,27 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { // Add input nodes and record inputs; Array input_names, output_names; std::set added_inputs; + // Add prims + for (const auto& p : func->params) { + if (!p->struct_info_.defined()) { + continue; + } + if (p->struct_info_.value()->IsInstance()) { + const auto& shape = ExprUtils::GetShape(p, false); + for (size_t i = 0; i < shape.size(); i++) { + if (shape[i]->IsInstance()) { + Map attrs; + attrs.Set("producer", p->name_hint()); + attrs.Set("out_idx", "0"); + attrs.Set("dim", std::to_string(i)); + MatchOrCreatePrim(shape[i], "shape", Array(), attrs); + } + } + } else { + LOG_FATAL << "Unexpected func param " << p << "(" << p->GetTypeKey() << ")"; + } + } + for (const auto& p : func->params) { if (expr_tensor_map_.count(p)) { continue; @@ -203,7 +224,7 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { } } // build graph - const auto& graph = MSCGraph(name_, valid_nodes, valid_inputs, output_names); + const auto& graph = MSCGraph(name_, valid_nodes, valid_inputs, output_names, prims_); // set inputs and outputs alias if (config_.input_aliases.size() == valid_inputs.size()) { for (size_t i = 0; i < valid_inputs.size(); i++) { @@ -471,14 +492,27 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } // Build output tensor - auto build_output = [](const relax::StructInfo& sinfo, const String& node_name, - const String& layout) { + auto build_output = [this](const relax::StructInfo& sinfo, const String& node_name, + const String& layout) { ICHECK(sinfo->IsInstance()) << "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey(); const auto& t_info = Downcast(sinfo); - const auto& shape_opt = t_info->GetShape(); - const auto& shape = - shape_opt.defined() ? ArrayUtils::Cast(shape_opt.value()) : Array(); + const auto& shape = ArrayUtils::Cast(ExprUtils::GetShape(t_info)); + Array prims; + bool has_prims = false; + if (shape.size() > 0) { + for (const auto& s : t_info->GetShape().value()) { + if (prim_map_.count(s)) { + prims.push_back(prim_map_[s]->name); + has_prims = true; + } else { + prims.push_back(StringUtils::ToString(s)); + } + } + } + if (has_prims) { + return MSCTensor(node_name, t_info->dtype, layout, shape, "", prims); + } return MSCTensor(node_name, t_info->dtype, layout, shape); }; @@ -552,6 +586,104 @@ void RelaxGraphBuilder::VisitBindingBlock(const relax::BindingBlock& block) { block_stack_.pop_back(); } +#define ADD_BINARY_PRIM(TypeName) \ + if (prim->IsInstance()) { \ + const auto& binary = Downcast(prim); \ + return MatchOrCreatePrim(prim, "", {AddPrim(binary->a), AddPrim(binary->b)}); \ + } + +const MSCPrim RelaxGraphBuilder::AddPrim(const PrimExpr& prim) { + if (prim_map_.count(prim)) { + return prim_map_[prim]; + } + + // binary + ADD_BINARY_PRIM(tvm::tir::Add) + ADD_BINARY_PRIM(tvm::tir::Sub) + ADD_BINARY_PRIM(tvm::tir::Mul) + ADD_BINARY_PRIM(tvm::tir::Div) + ADD_BINARY_PRIM(tvm::tir::Mod) + ADD_BINARY_PRIM(tvm::tir::FloorDiv) + ADD_BINARY_PRIM(tvm::tir::FloorMod) + ADD_BINARY_PRIM(tvm::tir::Max) + ADD_BINARY_PRIM(tvm::tir::Min) + + // compare + ADD_BINARY_PRIM(tvm::tir::EQ) + ADD_BINARY_PRIM(tvm::tir::NE) + ADD_BINARY_PRIM(tvm::tir::LT) + ADD_BINARY_PRIM(tvm::tir::LE) + ADD_BINARY_PRIM(tvm::tir::GT) + ADD_BINARY_PRIM(tvm::tir::GE) + + // scalar + if (prim->IsInstance()) { + Map attrs; + attrs.Set("value", StringUtils::ToString(prim)); + return MatchOrCreatePrim(prim, "Int", Array(), attrs); + } + + // call + if (const auto* c_node = prim.as()) { + String optype; + Array parents; + if (const auto* op_node = c_node->op.as()) { + optype = StringUtils::Replace(op_node->name, "tir.", ""); + } else { + optype = "Prim"; + } + for (const auto& a : c_node->args) { + parents.push_back(AddPrim(a)); + } + return MatchOrCreatePrim(prim, optype, parents); + } + return MatchOrCreatePrim(prim); +} + +const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String& optype, + const Array& parents, + const Map& attrs) { + if (prim_map_.count(prim)) { + return prim_map_[prim]; + } + const auto& op_ = + optype.size() == 0 ? StringUtils::Replace(prim->GetTypeKey(), "tir.", "") : optype; + for (const auto& p : prims_) { + if (p->optype != op_ || p->attrs.size() != attrs.size() || + p->parents.size() != parents.size()) { + continue; + } + bool attrs_match = std::all_of(p->attrs.begin(), p->attrs.end(), [&attrs](const auto& pair) { + return attrs.count(pair.first) && attrs[pair.first] == pair.second; + }); + if (!attrs_match) { + continue; + } + bool parents_match = true; + for (size_t i = 0; i < parents.size(); i++) { + if (p->ParentAt(i)->name != parents[i]->name) { + parents_match = false; + break; + } + } + if (!parents_match) { + continue; + } + prim_map_.Set(prim, p); + return p; + } + String name; + if (const auto* v_node = prim.as()) { + name = v_node->name_hint; + } else { + name = StringUtils::Upper(op_) + "_" + std::to_string(prims_.size()); + } + const auto& node = MSCPrim(prims_.size(), name, op_, parents, attrs); + prims_.push_back(node); + prim_map_.Set(prim, node); + return node; +} + void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) { AddNode(GetRef(op)); } @@ -649,6 +781,13 @@ const std::tuple RelaxGraphBuilder::ParseFunc(const rela return std::make_tuple(node_name, optype, layout); } +void RelaxGraphBuilder::VisitPrimExpr(const PrimExpr& prim) { + RelaxExprVisitor::VisitPrimExpr(prim); + if (!prim->IsInstance() && !prim->IsInstance()) { + AddPrim(prim); + } +} + Array RelaxGraphBuilder::GetPluginInputs(const relax::Expr& expr) { ICHECK(expr->IsInstance()) << "plugin expr should be call"; const auto& call = Downcast(expr); diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index d514a793475d..250fa38ef91b 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -265,6 +265,13 @@ class RelaxGraphBuilder : public RelaxExprVisitor { const MSCJoint AddNode(const Expr& expr, const Optional& binding_var = NullOpt, const String& name = ""); + /*! \brief Create and add MSCPrim from prim*/ + const MSCPrim AddPrim(const PrimExpr& prim); + + const MSCPrim MatchOrCreatePrim(const PrimExpr& prim, const String& op = "", + const Array& parents = Array(), + const Map& attrs = Map()); + void VisitBindingBlock(const relax::BindingBlock& block) final; void VisitExpr_(const relax::ConstantNode* op) final; @@ -286,6 +293,8 @@ class RelaxGraphBuilder : public RelaxExprVisitor { void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; + void VisitPrimExpr(const PrimExpr& prim) final; + private: /*! \brief Get the node_name, optype, layout for func*/ const std::tuple ParseFunc(const relax::Function& func); @@ -309,6 +318,9 @@ class RelaxGraphBuilder : public RelaxExprVisitor { // BYOC maps Map target_funcs_; Map func_params_; + // prims + Array prims_; + Map prim_map_; }; class RelaxWeightsExtractor : public RelaxExprVisitor { diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc index 317a39ab4e1a..a634b8e9e36a 100644 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -156,29 +156,30 @@ const LayoutDecision LayoutUtils::ExpandLayout(const LayoutDecision& src_layout, std::string new_layout = src_layout.name(); ICHECK_EQ(new_layout.size(), src_layout->layout.ndim()) << "Only support normal layout, get " << src_layout->layout; - std::vector priority_dims{"N", "C", "H", "W", "D", "G", "T"}; - size_t left_size = axes.size(); + std::set used_axes; + for (size_t i = 0; i < src_layout->layout.ndim(); i++) { + used_axes.insert(src_layout->layout[i].name()); + } + std::vector prefer_axes{"N", "C", "H", "W", "D"}; for (const auto& a : axes) { - std::string target = "U"; - if (new_layout.find("H") && !new_layout.find("W")) { - target = "W"; - } else if (new_layout.find("W") && !new_layout.find("H")) { - target = "H"; - } else if (left_size == 1 && new_layout.find("C") && !new_layout.find("D")) { - target = "D"; - } else if (left_size == 1 && new_layout.find("D") && !new_layout.find("C")) { - target = "C"; + bool use_prefer = false; + if (used_axes.size() < prefer_axes.size()) { + use_prefer = + std::all_of(prefer_axes.begin(), prefer_axes.begin() + used_axes.size(), + [&used_axes](const std::string& axis) { return used_axes.count(axis); }); + } + std::string new_axis; + char cur_axis = 'A'; + if (use_prefer) { + new_axis = prefer_axes[used_axes.size()]; } else { - for (const auto& p : priority_dims) { - int pos = new_layout.find(p); - if (pos < 0) { - target = p; - break; - } + while (used_axes.count(std::string(1, cur_axis))) { + cur_axis += 1; } + new_axis = std::string(1, cur_axis); } - new_layout = new_layout.insert(a, target); - left_size--; + used_axes.insert(new_axis); + new_layout = new_layout.insert(a, new_axis); } return LayoutDecision(new_layout); } @@ -220,6 +221,18 @@ const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout return LayoutDecision(layout_str); } +int LayoutUtils::InferBatchDim(const LayoutDecision& layout) { + if (!layout->layout.defined()) { + return -1; + } + for (size_t i = 0; i < layout->layout.ndim(); i++) { + if (layout->layout[i].name() == "N") { + return static_cast(i); + } + } + return -1; +} + } // namespace msc } // namespace contrib } // namespace tvm diff --git a/src/contrib/msc/core/transform/layout_utils.h b/src/contrib/msc/core/transform/layout_utils.h index 7748f217d6ec..e7781a95a8f7 100644 --- a/src/contrib/msc/core/transform/layout_utils.h +++ b/src/contrib/msc/core/transform/layout_utils.h @@ -123,6 +123,12 @@ class LayoutUtils { const Array& axes); TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, const std::vector& axes); + + /*! + * \brief Infer batch dim from the Layout + * \return The batch dim. + */ + TVM_DLL static int InferBatchDim(const LayoutDecision& layout); }; } // namespace msc diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 56517fdae8d6..a3902a44bfaa 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -34,49 +34,11 @@ namespace relax { using namespace tvm::contrib::msc; -std::tuple AccumulateMatch(const std::vector& in_shape, - const std::vector& out_shape, size_t in_start, +std::tuple AccumulateMatch(const Array& input_shape, + const Array& output_shape, size_t in_start, size_t out_start) { // find input position in_pos and output position out_pos - // cumsum(in_shape[in_start:in_ops])==cumsum(out_shape[out_start:out_pos]) - int64_t in_pos = -1; - int64_t out_pos = -1; - int64_t in_accumulate = 1; - int64_t out_accumulate = 1; - for (size_t i = in_start; i < in_shape.size(); i++) { - in_accumulate *= in_shape[i]; - out_accumulate = 1; - for (size_t j = out_start; j < out_shape.size(); j++) { - out_accumulate *= out_shape[j]; - if (in_accumulate == out_accumulate) { - in_pos = i; - out_pos = j; - break; - } else if (out_accumulate > in_accumulate) { - break; - } - } - if (in_pos >= 0) { - break; - } - } - // append tailed 1s - if (in_pos >= 0) { - int64_t in_size = static_cast(in_shape.size()); - int64_t out_size = static_cast(out_shape.size()); - while (in_pos < in_size - 1 && in_shape[in_pos + 1] == 1) { - in_pos++; - } - while (out_pos < out_size - 1 && out_shape[out_pos + 1] == 1) { - out_pos++; - } - } - return std::make_tuple(in_pos, out_pos); -} - -std::vector InferReduceAxes(const Array& input_shape, - const Array& output_shape) { - std::vector reduce_axes, out_axes; + // cumsum(in_shape[in_start:in_pos])==cumsum(out_shape[out_start:out_pos]) std::vector in_shape, out_shape; for (const auto& s : input_shape) { in_shape.push_back(Downcast(s)->value); @@ -84,71 +46,76 @@ std::vector InferReduceAxes(const Array& input_shape, for (const auto& s : output_shape) { out_shape.push_back(Downcast(s)->value); } - size_t start = 0; - while (start < in_shape.size() && out_axes.size() < out_shape.size()) { - if (in_shape[start] == out_shape[out_axes.size()]) { - out_axes.push_back(start); - start++; - } else { - int64_t in_pos, out_pos; - size_t out_start = out_axes.size(); - std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start); - if (in_pos == -1) { - return std::vector(); + int64_t in_size = static_cast(in_shape.size()); + int64_t out_size = static_cast(out_shape.size()); + int64_t in_pos = in_start; + int64_t out_pos = out_start; + int64_t in_accumulate = in_shape[in_pos]; + int64_t out_accumulate = out_shape[out_pos]; + while (in_accumulate != out_accumulate) { + if (in_accumulate > out_accumulate) { + out_pos += 1; + if (out_pos >= out_size) { + return std::make_tuple(-1, -1); } - for (size_t i = out_start; i < static_cast(out_pos) + 1; i++) { - out_axes.push_back(i + 1); + out_accumulate *= out_shape[out_pos]; + } else { + in_pos += 1; + if (in_pos >= in_size) { + return std::make_tuple(-1, -1); } - start = in_pos + 1; + in_accumulate *= in_shape[in_pos]; } } - if (out_axes.size() != out_shape.size()) { - return std::vector(); - } - std::set out_axes_set; - for (const auto& a : out_axes) { - out_axes_set.insert(a); + if (in_accumulate != out_accumulate) { + return std::make_tuple(-1, -1); } - for (size_t i = 0; i < in_shape.size(); i++) { - if (!out_axes_set.count(i)) { - reduce_axes.push_back(i); + // append tailing + if (in_pos >= 0) { + while (in_pos < in_size - 1 && in_shape[in_pos + 1] == 1) { + in_pos++; + } + while (out_pos < out_size - 1 && out_shape[out_pos + 1] == 1) { + out_pos++; } } - return reduce_axes; + return std::make_tuple(in_pos - in_start, out_pos - out_start); } -std::vector InferExpandAxes(const Array& input_shape, - const Array& output_shape) { - std::vector expand_axes; - std::vector in_shape, out_shape; - for (const auto& s : input_shape) { - in_shape.push_back(Downcast(s)->value); - } - for (const auto& s : output_shape) { - out_shape.push_back(Downcast(s)->value); - } - size_t start = 0; - while (start < in_shape.size() && expand_axes.size() + in_shape.size() < out_shape.size()) { - if (in_shape[start] == out_shape[start + expand_axes.size()]) { - start++; - } else { - int64_t in_pos, out_pos; - size_t out_start = start + expand_axes.size(); - std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start); - if (in_pos == -1) { - return std::vector(); +std::tuple, std::vector> InferReshapeAxes( + const Array& input_shape, const Array& output_shape, int batch_dim) { + std::vector expand_axes, reduce_axes; + size_t in_start = 0; + while (in_start < input_shape.size()) { + size_t out_start = in_start + expand_axes.size() - reduce_axes.size(); + int64_t in_dist, out_dist; + std::tie(in_dist, out_dist) = AccumulateMatch(input_shape, output_shape, in_start, out_start); + if (in_dist == -1) { + return std::make_tuple(std::vector(), std::vector()); + } + if (out_dist >= in_dist) { + for (size_t i = 0; i < static_cast(out_dist - in_dist); i++) { + if (batch_dim >= 0 && (out_start + i) == static_cast(batch_dim)) { + expand_axes.push_back(out_start + i + 1); + } else { + expand_axes.push_back(out_start + i); + } } - size_t expand_size = out_pos - in_pos - expand_axes.size(); - for (size_t i = 0; i < expand_size; i++) { - expand_axes.push_back(out_start + i); + } else { + for (size_t i = 0; i < static_cast(in_dist - out_dist); i++) { + if (batch_dim >= 0 && (in_start + i) == static_cast(batch_dim)) { + reduce_axes.push_back(in_start + i + 1); + } else { + reduce_axes.push_back(in_start + i); + } } - start = in_pos + 1; } + in_start += in_dist + 1; } - if (expand_axes.size() + in_shape.size() != out_shape.size()) { - return std::vector(); + if (input_shape.size() + expand_axes.size() - reduce_axes.size() != output_shape.size()) { + return std::make_tuple(std::vector(), std::vector()); } - return expand_axes; + return std::make_tuple(expand_axes, reduce_axes); } // Forward and Backward infer @@ -167,6 +134,11 @@ InferLayoutOutput MSCInferLayoutConv(const Call& call, data_layout = LayoutDecision(attrs->data_layout); kernel_layout = LayoutDecision(attrs->kernel_layout); out_layout = LayoutDecision(attrs->out_layout); + } else if (op_name == "relax.nn.conv2d_transpose") { + const auto* attrs = call->attrs.as(); + data_layout = LayoutDecision(attrs->data_layout); + kernel_layout = LayoutDecision(attrs->kernel_layout); + out_layout = LayoutDecision(attrs->out_layout); } return InferLayoutOutput({data_layout, kernel_layout}, {out_layout}, Attrs()); } @@ -213,18 +185,48 @@ InferLayoutOutput ForwardInferLayoutCommon(const Call& call, if (!layout_hint.defined()) { return InferLayoutOutput(); } - std::vector output_layouts; const auto& sinfo = GetStructInfo(call); if (sinfo->IsInstance()) { - output_layouts.push_back(layout_hint); - } else if (const auto* tuple_sinfo = sinfo.as()) { + return InferLayoutOutput(input_layouts, {layout_hint}, Attrs()); + } + Array output_layouts; + if (const auto* tuple_sinfo = sinfo.as()) { for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) { output_layouts.push_back(layout_hint); } - } else { + return InferLayoutOutput(input_layouts, {output_layouts}, Attrs()); + } + return InferLayoutOutput(); +} + +InferLayoutOutput ForwardInferLayoutBroadcast(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + Array input_layouts; + LayoutDecision layout_hint; + for (const auto& arg : call->args) { + const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); + if (in_layout->layout.defined()) { + if (!layout_hint.defined() || layout_hint->layout.ndim() < in_layout->layout.ndim()) { + layout_hint = in_layout; + } + } + input_layouts.push_back(in_layout); + } + if (!layout_hint.defined()) { return InferLayoutOutput(); } - return InferLayoutOutput(input_layouts, {output_layouts}, Attrs()); + const auto& sinfo = GetStructInfo(call); + if (sinfo->IsInstance()) { + return InferLayoutOutput(input_layouts, {layout_hint}, Attrs()); + } + return InferLayoutOutput(); +} + +InferLayoutOutput ForwardInferLayoutInplace(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); } InferLayoutOutput ForwardInferLayoutBinary(const Call& call, @@ -253,12 +255,6 @@ InferLayoutOutput ForwardInferLayoutBinary(const Call& call, return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); } -InferLayoutOutput ForwardInferLayoutInplace(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); -} - InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { @@ -273,9 +269,7 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, if (!attrs->axis.defined()) { return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -288,9 +282,7 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -314,9 +306,7 @@ InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, if (!input_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -332,9 +322,7 @@ InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -353,12 +341,8 @@ InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - Array empty; - const auto& a_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); - const auto& b_shape = - Downcast(GetStructInfo(call->args[1]))->GetShape().value_or(empty); - + const auto& a_shape = ExprUtils::GetShape(call->args[0]); + const auto& b_shape = ExprUtils::GetShape(call->args[1]); if (a_shape.size() == 0) { return InferLayoutOutput(); } @@ -417,9 +401,7 @@ InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, if (!attrs->axis.defined()) { return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -438,29 +420,25 @@ InferLayoutOutput ForwardInferLayoutReshape(const Call& call, if (!input_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); - const auto& output_shape = - Downcast(GetStructInfo(call))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(call); if (input_shape.size() == 0 || output_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision output_layout; - if (input_shape.size() == output_shape.size()) { - output_layout = input_layout; - } else if (input_shape.size() > output_shape.size()) { - const auto& reduce_axes = InferReduceAxes(input_shape, output_shape); - if (reduce_axes.size() == 0) { + LayoutDecision output_layout = input_layout; + if (input_shape.size() != output_shape.size()) { + int batch_dim = LayoutUtils::InferBatchDim(input_layout); + std::vector expand_axes, reduce_axes; + std::tie(expand_axes, reduce_axes) = InferReshapeAxes(input_shape, output_shape, batch_dim); + if (reduce_axes.size() == 0 && expand_axes.size() == 0) { return InferLayoutOutput(); } - output_layout = LayoutUtils::ReduceLayout(input_layout, reduce_axes); - } else { - const auto& expand_axes = InferExpandAxes(input_shape, output_shape); - if (expand_axes.size() == 0) { - return InferLayoutOutput(); + if (reduce_axes.size() > 0) { + output_layout = LayoutUtils::ReduceLayout(output_layout, reduce_axes); + } + if (expand_axes.size() > 0) { + output_layout = LayoutUtils::ExpandLayout(output_layout, expand_axes); } - output_layout = LayoutUtils::ExpandLayout(input_layout, expand_axes); } return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); } @@ -472,9 +450,7 @@ InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, if (!input_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -501,12 +477,27 @@ InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, InferLayoutOutput ForwardInferLayoutTake(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); - if (!input_layout->layout.defined()) { + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(call); + if (input_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout, std::vector{0}); - return InferLayoutOutput({LayoutDecision("WE"), input_layout}, {output_layout}, Attrs()); + if (input_layout->layout.defined()) { + if (input_shape.size() == output_shape.size()) { + return InferLayoutOutput({input_layout, indices_layout}, {input_layout}, Attrs()); + } + LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, std::vector{0}); + return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); + } + if (indices_layout->layout.defined()) { + size_t indices_size = indices_layout->layout.ndim(); + LayoutDecision output_layout = + LayoutUtils::ExpandLayout(indices_layout, std::vector{indices_size}); + return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); + } + return InferLayoutOutput(); } InferLayoutOutput ForwardInferLayoutPlugin(const Call& call, @@ -524,18 +515,27 @@ InferLayoutOutput ForwardInferLayoutPlugin(const Call& call, return (*pf)(args->fields, var_layout_map); } +// nn ops +TVM_REGISTER_OP("relax.nn.avg_pool2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBatchNorm); TVM_REGISTER_OP("relax.nn.conv1d") .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); TVM_REGISTER_OP("relax.nn.conv2d") .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.conv2d_transpose") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.dropout") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutCommon); +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); TVM_REGISTER_OP("relax.nn.max_pool2d") .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.avg_pool2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.image.resize2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutResize2d); // reduce axis ops TVM_REGISTER_OP("relax.argmax") @@ -554,6 +554,7 @@ TVM_REGISTER_OP("relax.prod") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); TVM_REGISTER_OP("relax.std") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); + // binary ops TVM_REGISTER_OP("relax.add") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); @@ -609,14 +610,8 @@ TVM_REGISTER_OP("relax.squeeze") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutSqueeze); TVM_REGISTER_OP("relax.take") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutTake); - -// nn ops -TVM_REGISTER_OP("relax.nn.batch_norm") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBatchNorm); -TVM_REGISTER_OP("relax.nn.group_norm") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); -TVM_REGISTER_OP("relax.nn.layer_norm") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.image.resize2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutResize2d); // plugin op TVM_REGISTER_OP("relax.call_dps_packed") @@ -695,9 +690,7 @@ InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, if (attrs->keepdims) { return InferLayoutOutput({output_layout}, {output_layout}, Attrs()); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -726,9 +719,7 @@ InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -759,9 +750,7 @@ InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& b_shape = - Downcast(GetStructInfo(call->args[1]))->GetShape().value_or(empty); + const auto& b_shape = ExprUtils::GetShape(call->args[1]); if (b_shape.size() == 0) { return InferLayoutOutput(); } @@ -816,9 +805,7 @@ InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, if (attrs->keepdims) { return InferLayoutOutput({output_layout}, {output_layout}, Attrs()); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -837,29 +824,25 @@ InferLayoutOutput BackwardInferLayoutReshape(const Call& call, if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); - const auto& output_shape = - Downcast(GetStructInfo(call))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(call); if (input_shape.size() == 0 || output_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision input_layout; - if (input_shape.size() == output_shape.size()) { - input_layout = output_layout; - } else if (input_shape.size() > output_shape.size()) { - const auto& reduce_axes = InferReduceAxes(input_shape, output_shape); - if (reduce_axes.size() == 0) { + LayoutDecision input_layout = output_layout; + if (input_shape.size() != output_shape.size()) { + int batch_dim = LayoutUtils::InferBatchDim(output_layout); + std::vector expand_axes, reduce_axes; + std::tie(expand_axes, reduce_axes) = InferReshapeAxes(input_shape, output_shape, batch_dim); + if (reduce_axes.size() == 0 && expand_axes.size() == 0) { return InferLayoutOutput(); } - input_layout = LayoutUtils::ExpandLayout(output_layout, reduce_axes); - } else { - const auto& expand_axes = InferExpandAxes(input_shape, output_shape); - if (expand_axes.size() == 0) { - return InferLayoutOutput(); + if (expand_axes.size() > 0) { + input_layout = LayoutUtils::ReduceLayout(input_layout, expand_axes); + } + if (reduce_axes.size() > 0) { + input_layout = LayoutUtils::ExpandLayout(input_layout, reduce_axes); } - input_layout = LayoutUtils::ReduceLayout(output_layout, expand_axes); } return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); } @@ -871,9 +854,7 @@ InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -901,12 +882,28 @@ InferLayoutOutput BackwardInferLayoutTake(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(call); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout, std::vector{0}); - return InferLayoutOutput({LayoutDecision("WE"), input_layout}, {output_layout}, Attrs()); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + if (!indices_layout.defined()) { + indices_layout = LayoutUtils::ReduceLayout(output_layout, std::vector{0}); + } + if (input_shape.size() == output_shape.size()) { + return InferLayoutOutput({output_layout, indices_layout}, {output_layout}, Attrs()); + } + if (!input_layout.defined()) { + input_layout = LayoutUtils::ExpandLayout(output_layout, std::vector{0}); + } + return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); } + InferLayoutOutput BackwardInferLayoutTupleInputs(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { @@ -925,18 +922,25 @@ InferLayoutOutput BackwardInferLayoutTupleInputs(const Call& call, return InferLayoutOutput(input_layouts, {output_layout}, Attrs()); } +// nn ops +TVM_REGISTER_OP("relax.nn.avg_pool2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBatchNorm); TVM_REGISTER_OP("relax.nn.conv1d") .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); TVM_REGISTER_OP("relax.nn.conv2d") .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.conv2d_transpose") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); TVM_REGISTER_OP("relax.nn.max_pool2d") .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.avg_pool2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.image.resize2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutResize2d); // reduce axis ops TVM_REGISTER_OP("relax.argmax") @@ -1013,14 +1017,8 @@ TVM_REGISTER_OP("relax.squeeze") .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutSqueeze); TVM_REGISTER_OP("relax.take") .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutTake); - -// nn ops -TVM_REGISTER_OP("relax.nn.batch_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBatchNorm); -TVM_REGISTER_OP("relax.nn.group_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); -TVM_REGISTER_OP("relax.nn.layer_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.image.resize2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutResize2d); class LayoutInfer : public ExprVisitor { public: @@ -1268,9 +1266,13 @@ class LayoutInfer : public ExprVisitor { SetExprLayout(call->args[i], var_layout_map_[func->params[i]]); } } - if (func->body->body->IsInstance() && - var_layout_map_.count(Downcast(func->body->body))) { - SetExprLayout(ret, var_layout_map_[Downcast(func->body->body)]); + if (const auto* b_node = func->body.as()) { + if (b_node->body->IsInstance() && + var_layout_map_.count(Downcast(b_node->body))) { + SetExprLayout(ret, var_layout_map_[Downcast(b_node->body)]); + } + } else { + LOG(FATAL) << "Function body should be SeqExpr, get " << func->body; } } @@ -1284,9 +1286,13 @@ class LayoutInfer : public ExprVisitor { if (producer->IsInstance() && local_funcs_.count(Downcast(producer)->op)) { const auto& caller = local_funcs_[Downcast(producer)->op]; - if (caller->body->body->IsInstance() && - var_map_.count(Downcast(caller->body->body))) { - SetExprLayout(caller->body->body, param_layout); + if (const auto* b_node = caller->body.as()) { + if (b_node->body->IsInstance() && + var_map_.count(Downcast(b_node->body))) { + SetExprLayout(b_node->body, param_layout); + } + } else { + LOG(FATAL) << "Caller body should be SeqExpr, get " << caller->body; } } } @@ -1298,7 +1304,7 @@ class LayoutInfer : public ExprVisitor { bool infered_; Map var_map_; Array ordered_exprs_; - std::unordered_map var_layout_map_; + std::unordered_map var_layout_map_; Map local_funcs_; }; // class LayoutInfer diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 9e437f705c34..634dd7969889 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -141,7 +141,7 @@ const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTFV1OpCodes(); auto it = ops_map->find(node->optype); ICHECK(it != ops_map->end()) << "Unsupported tensorflow op(" << node->optype << "): " << node; - it->second->Config(node, config()); + it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { @@ -154,6 +154,7 @@ TVM_REGISTER_GLOBAL("msc.framework.tensorflow.GetTensorflowSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); + codegen.Init(); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 717eb75e1f36..a9c16994e5b6 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -544,7 +544,7 @@ const Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTensorRTOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported tensorrt op(" << node->optype << "): " << node; - it->second->Config(node, config()); + it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { @@ -578,6 +578,7 @@ TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); + codegen.Init(); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index 54859ad0ce89..86351bdd060b 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -142,7 +142,7 @@ const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTorchOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported torch op(" << node->optype << "): " << node; - it->second->Config(node, config(), is_init_); + it->second->Config(node, config(), is_init_, prims()); try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { @@ -155,6 +155,7 @@ TVM_REGISTER_GLOBAL("msc.framework.torch.GetTorchSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); + codegen.Init(); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index e355626f859f..9ae825b804aa 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -202,6 +202,13 @@ class TorchClipCodeGen : public TorchOpCode { } }; +class TorchConcatCodeGen : public TorchOpCode { + TORCH_OP_CODEGEN_METHODS(TorchConcatCodeGen); + + protected: + void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg("axis", "dim"); } +}; + class TorchConstantCodeGen : public TorchOpCode { TORCH_OP_CODEGEN_METHODS(TorchConstantCodeGen); @@ -298,8 +305,8 @@ class TorchEmbeddingCodeGen : public TorchOpCode { void CodeGenInit() final { const auto& weight = node()->WeightAt("weight"); stack_.op_call() - .call_arg(weight->DimAt("W"), "num_embeddings") - .call_arg(weight->DimAt("E"), "embedding_dim"); + .call_arg(weight->DimAt(0), "num_embeddings") + .call_arg(weight->DimAt(1), "embedding_dim"); } }; @@ -706,6 +713,7 @@ const std::shared_ptr>> map->emplace("astype", std::make_shared("", "to")); map->emplace("broadcast_to", std::make_shared("", "expand")); map->emplace("clip", std::make_shared("", "torch.clamp")); + map->emplace("concat", std::make_shared("", "torch.cat")); map->emplace("cumsum", std::make_shared("", "torch.cumsum")); map->emplace("expand_dims", std::make_shared("", "torch.unsqueeze")); map->emplace("permute_dims", std::make_shared("", "torch.permute")); diff --git a/src/contrib/msc/framework/torch/torch_opcode.h b/src/contrib/msc/framework/torch/torch_opcode.h index 6fe5cf5f96c4..80b7f5c60d1d 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.h +++ b/src/contrib/msc/framework/torch/torch_opcode.h @@ -55,9 +55,9 @@ class TorchOpCode : public BaseOpCode { } /*! \brief Config the TorchOpCode*/ - void Config(const MSCJoint& node, const std::shared_ptr config, - bool is_init) { - BaseOpCode::Config(node, config); + void Config(const MSCJoint& node, const std::shared_ptr config, bool is_init, + const Map& prims) { + BaseOpCode::Config(node, config, prims); is_init_ = is_init; module_ref_ = "self." + StringUtils::Replace(node->name, ".", "_"); } diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 783551eed35b..5443cdc96a05 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -187,11 +187,21 @@ void RelaxCodeGen::CodeGenInference() { } } +const String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { + if (prim->optype == "shape") { + const auto& producer = graph()->FindNode(prim->GetTypeAttr("producer")); + int out_idx = prim->GetTypeAttr("out_idx"); + const auto& dim = prim->GetTypeAttr("dim"); + return IdxOutputBase(producer, out_idx) + ".struct_info.shape[" + dim + "]"; + } + return PyCodeGen::DescribePrim(prim); +} + const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetRelaxOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported relax op(" << node->optype << "): " << node; - it->second->Config(node, config()); + it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { @@ -204,6 +214,7 @@ TVM_REGISTER_GLOBAL("msc.framework.tvm.GetRelaxSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); + codegen.Init(); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/tvm/codegen.h b/src/contrib/msc/framework/tvm/codegen.h index 944d4cdfe1cc..249105b5a50b 100644 --- a/src/contrib/msc/framework/tvm/codegen.h +++ b/src/contrib/msc/framework/tvm/codegen.h @@ -55,6 +55,9 @@ class RelaxCodeGen : public PyCodeGen { /*! \brief Stack the docs for the graph inference*/ void CodeGenInference() final; + /*! \brief Describe the prim*/ + const String DescribePrim(const MSCPrim& prim) final; + /*! \brief Get the docs for the op*/ const Array GetOpCodes(const MSCJoint& node) final; diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 0b7ef6aa825e..1913e8ecda8e 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -562,12 +562,8 @@ class RelaxReshapeCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - stack_.op_call().op_input_arg(); - if (config()->from_relay) { - stack_.op_list_arg("newshape", "shape"); - } else { - stack_.op_list_arg("shape"); - } + const auto& out_shape = GetPrims(node()->OutputAt(0)); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(out_shape), "shape"); } }; diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index d02767208206..60c8a73dcc67 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -14,20 +14,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """ Test graph builder && graph. """ +import pytest import torch from torch import fx from torch.nn import Module import tvm.testing from tvm.relax.frontend.torch import from_fx -from tvm.contrib.msc.core.frontend import translate +from tvm.contrib.msc.core.frontend import translate, normalize_inputs from tvm.contrib.msc.core import utils as msc_utils def verify_model(torch_model, input_info, expected): + input_info = normalize_inputs(input_info) graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): mod = from_fx(graph_model, input_info) @@ -38,7 +41,8 @@ def verify_model(torch_model, input_info, expected): ) -def test_conv1d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_conv1d(dynamic): """test graph builder for conv1d""" class Conv1D1(Module): @@ -49,12 +53,6 @@ def __init__(self): def forward(self, data): return self.conv(data) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10], "dtype": "float32", "layout": "NCW"}], - "outputs": [{"name": "conv1d", "shape": [1, 6, 4], "dtype": "float32", "layout": "NCW"}], - "nodes": {"total": 2, "input": 1, "msc.conv1d_bias": 1}, - } - class Conv1D2(Module): def __init__(self): super().__init__() @@ -63,18 +61,28 @@ def __init__(self): def forward(self, data): return self.conv(data) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10], "dtype": "float32", "layout": "NCW"}], + "outputs": [{"name": "conv1d", "shape": [bz, 6, 4], "dtype": "float32", "layout": "NCW"}], + "nodes": {"total": 2, "input": 1, "msc.conv1d_bias": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10], "dtype": "float32", "layout": "NCW"}], - "outputs": [{"name": "conv1d", "shape": [1, 6, 4], "dtype": "float32", "layout": "NCW"}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10], "dtype": "float32", "layout": "NCW"}], + "outputs": [{"name": "conv1d", "shape": [bz, 6, 4], "dtype": "float32", "layout": "NCW"}], "nodes": {"total": 2, "input": 1, "nn.conv1d": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10], "float32")] + input_info = [([bz, 3, 10], "float32")] verify_model(Conv1D1(), input_info, expected1) verify_model(Conv1D2(), input_info, expected2) -def test_conv2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_conv2d(dynamic): """test graph builder for conv2d""" class Conv2D1(Module): @@ -85,44 +93,49 @@ def __init__(self): def forward(self, data): return self.conv(data) + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, data): + return self.conv(data) + + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "conv2d", - "shape": [1, 6, 4, 4], + "shape": [bz, 6, 4, 4], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1}, } - - class Conv2D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "conv2d", "shape": [1, 6, 4, 4], "dtype": "float32", "layout": "NCHW"} + {"name": "conv2d", "shape": [bz, 6, 4, 4], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.conv2d": 1}, } - input_info = [([1, 3, 10, 10], "float32")] + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Conv2D1(), input_info, expected1) verify_model(Conv2D2(), input_info, expected2) -def test_linear(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_linear(dynamic): """test graph builder for linear""" class Dense1(Module): @@ -133,123 +146,139 @@ def __init__(self): def forward(self, data): return self.linear(data) + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, data): + return self.linear(data) + + class MatMul1(Module): + def forward(self, x, y): + return torch.matmul(x, y) + + bz = "bz" if dynamic else 1 + mdim = "mdim" if dynamic else 10 + ndim = "ndim" if dynamic else 20 + kdim = "kdim" if dynamic else 30 + expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "matmul", - "shape": [1, 3, 10, 7], + "shape": [bz, 3, 10, 7], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 2, "input": 1, "msc.linear_bias": 1}, } - - class Dense2(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=False) - - def forward(self, data): - return self.linear(data) - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "matmul", "shape": [1, 3, 10, 7], "dtype": "float32", "layout": "NCHW"} + {"name": "matmul", "shape": [bz, 3, 10, 7], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "msc.linear": 1}, } - - class MatMul1(Module): - def forward(self, x, y): - return torch.matmul(x, y) - expected3 = { "inputs": [ - {"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": "NC"}, - {"name": "inp_1", "shape": [10, 10], "dtype": "float32", "layout": "IO"}, + {"name": "inp_0", "shape": [mdim, kdim], "dtype": "float32", "layout": "NC"}, + {"name": "inp_1", "shape": [kdim, ndim], "dtype": "float32", "layout": "IO"}, ], - "outputs": [{"name": "matmul", "shape": [10, 10], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "matmul", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}], "nodes": {"total": 3, "input": 2, "matmul": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + expected3["prims"] = {"total": 3, "shape": 3} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Dense1(), input_info, expected1) verify_model(Dense2(), input_info, expected2) - verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")], expected3) + verify_model(MatMul1(), [([mdim, kdim], "float32"), ([kdim, ndim], "float32")], expected3) -def test_bmm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_bmm(dynamic): """test graph builder for bmm""" class BMM(Module): def forward(self, x, y): return torch.bmm(x, y) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [4, 128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_1", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"}, + {"name": "inp_0", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, + {"name": "inp_1", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, ], "outputs": [ - {"name": "matmul", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"} + {"name": "matmul", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"} ], "nodes": {"total": 3, "input": 2, "matmul": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] + input_info = [((bz, 128, 256), "float32"), ((bz, 256, 512), "float32")] verify_model(BMM(), input_info, expected) -def test_baddbmm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_baddbmm(dynamic): """test graph builder for baddbmm""" class BAddBMM1(Module): def forward(self, c, x, y): return torch.baddbmm(c, x, y) + class BAddBMM2(Module): + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=0) + + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_1", "shape": [4, 128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_2", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"}, + {"name": "inp_0", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}, + {"name": "inp_1", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, + {"name": "inp_2", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, ], - "outputs": [{"name": "add", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"}], + "outputs": [{"name": "add", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}], "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1}, } - - class BAddBMM2(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y, alpha=2, beta=0) - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [4, 128, 512], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [4, 128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_2", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"}, + {"name": "inp_0", "shape": [bz, 128, 512], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, + {"name": "inp_2", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, ], "outputs": [ - {"name": "multiply", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"} + {"name": "multiply", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"} ], "nodes": {"total": 6, "input": 3, "matmul": 1, "constant": 1, "multiply": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} input_info = [ - ((4, 128, 512), "float32"), - ((4, 128, 256), "float32"), - ((4, 256, 512), "float32"), + ((bz, 128, 512), "float32"), + ((bz, 128, 256), "float32"), + ((bz, 256, 512), "float32"), ] verify_model(BAddBMM1(), input_info, expected1) verify_model(BAddBMM2(), input_info, expected2) -def test_relu(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_relu(dynamic): """test graph builder for relu""" class ReLU(Module): @@ -264,18 +293,22 @@ class ReLU1(Module): def forward(self, data): return torch.nn.functional.relu(data) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "relu", "shape": [10, 10], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "relu", "shape": [bz, 10], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "nn.relu": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([10, 10], "float32")] + input_info = [([bz, 10], "float32")] verify_model(ReLU(), input_info, expected) verify_model(ReLU1(), input_info, expected) -def test_relu6(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_relu6(dynamic): """test graph builder for relu6""" class ReLU6(Module): @@ -286,16 +319,21 @@ def __init__(self): def forward(self, data): return self.relu6(data) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "clip", "shape": [10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "clip", "shape": [bz, 10], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "clip": 1}, } - input_info = [([10, 10], "float32")] + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} + + input_info = [([bz, 10], "float32")] verify_model(ReLU6(), input_info, expected) -def test_maxpool2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_maxpool2d(dynamic): """test graph builder for maxpool2d""" class MaxPool2d(Module): @@ -306,16 +344,6 @@ def __init__(self): def forward(self, data): return self.pool(data) - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "max_pool2d", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, - } - class MaxPool2d2(Module): def __init__(self): super().__init__() @@ -324,16 +352,6 @@ def __init__(self): def forward(self, data): return self.pool(data) - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "max_pool2d", "shape": [1, 3, 4, 4], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, - } - class MaxPool2d3(Module): def __init__(self): super().__init__() @@ -342,23 +360,47 @@ def __init__(self): def forward(self, data): return self.pool(data) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [ + {"name": "max_pool2d", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, + } + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [ + {"name": "max_pool2d", "shape": [bz, 3, 4, 4], "dtype": "float32", "layout": "NCHW"} + ], + "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, + } expected3 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "max_pool2d", "shape": [1, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} + {"name": "max_pool2d", "shape": [bz, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + expected3["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(MaxPool2d(), input_info, expected1) verify_model(MaxPool2d2(), input_info, expected2) verify_model(MaxPool2d3(), input_info, expected3) -def test_avgpool2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_avgpool2d(dynamic): """test graph builder for avgpool2d""" class AvgPool2d(Module): @@ -369,16 +411,6 @@ def __init__(self): def forward(self, data): return self.pool(data) - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "avg_pool2d", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1}, - } - class AvgPool2d2(Module): def __init__(self): super().__init__() @@ -387,22 +419,36 @@ def __init__(self): def forward(self, data): return self.pool(data) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [ + {"name": "avg_pool2d", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1}, + } expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "avg_pool2d", "shape": [1, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} + {"name": "avg_pool2d", "shape": [bz, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(AvgPool2d(), input_info, expected1) verify_model(AvgPool2d2(), input_info, expected2) -def test_adaptive_avgpool2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_adaptive_avgpool2d(dynamic): """test graph builder for adaptive_avgpool2d""" class AdaptiveAvgPool2d0(Module): @@ -413,26 +459,30 @@ def __init__(self): def forward(self, data): return self.pool(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "adaptive_avg_pool2d", - "shape": [1, 3, 10, 10], + "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 2, "input": 1, "nn.adaptive_avg_pool2d": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(AdaptiveAvgPool2d0(), input_info, expected) -def test_flatten(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_flatten(dynamic): """test graph builder for flatten""" class Flatten(Module): @@ -443,18 +493,26 @@ def __init__(self): def forward(self, data): return self.f(data) + bz = "bz" if dynamic else 1 + dim = "dim" if dynamic else 10 + out_dim = "MUL_3" if dynamic else 100 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "reshape", "shape": [1, 3, 100], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, dim], "dtype": "float32", "layout": ""}], + "outputs": [ + {"name": "reshape", "shape": [bz, 3, out_dim], "dtype": "float32", "layout": ""} + ], "nodes": {"total": 2, "input": 1, "reshape": 1}, } + if dynamic: + expected["prims"] = {"total": 4, "shape": 2, "Int": 1, "Mul": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, dim], "float32")] verify_model(Flatten(), input_info, expected) verify_model(torch.nn.Flatten(2, -1), input_info, expected) -def test_batchnorm2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_batchnorm2d(dynamic): """test graph builder for batchnorm2d""" class BatchNorm2d(Module): @@ -465,26 +523,30 @@ def __init__(self): def forward(self, data): return self.batchnorm(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "batch_norm.0", - "shape": [1, 3, 10, 10], + "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 3, "input": 1, "nn.batch_norm": 1, "get_item": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(BatchNorm2d(), input_info, expected) -def test_embedding(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_embedding(dynamic): """test graph builder for embedding""" class Embedding(Module): @@ -495,23 +557,34 @@ def __init__(self): def forward(self, data): return self.embedding(data) + vocab = "vocab" if dynamic else 4 expected1 = { - "inputs": [{"name": "inp_0", "shape": [4], "dtype": "int64", "layout": "A"}], - "outputs": [{"name": "take", "shape": [4, 3], "dtype": "float32", "layout": "NA"}], + "inputs": [{"name": "inp_0", "shape": [vocab], "dtype": "int64", "layout": "A"}], + "outputs": [{"name": "take", "shape": [vocab, 3], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "msc.embedding": 1}, } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [4, 5], "dtype": "int64", "layout": "AB"}], - "outputs": [{"name": "take", "shape": [4, 5, 3], "dtype": "float32", "layout": "CNB"}], + "inputs": [{"name": "inp_0", "shape": [vocab, 5], "dtype": "int64", "layout": "AB"}], + "outputs": [ + { + "name": "take", + "shape": [vocab, 5, 3], + "dtype": "float32", + "layout": "" if dynamic else "CBA", + } + ], "nodes": {"total": 2, "input": 1, "msc.embedding": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - verify_model(Embedding(), [([4], "int64")], expected1) - verify_model(Embedding(), [([4, 5], "int64")], expected2) + verify_model(Embedding(), [([vocab], "int64")], expected1) + verify_model(Embedding(), [([vocab, 5], "int64")], expected2) -def test_dropout(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_dropout(dynamic): """test graph builder for dropout""" class Dropout1(Module): @@ -526,18 +599,22 @@ class Dropout2(Module): def forward(self, data): return torch.dropout(data, 0.5, train=True) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], "nodes": {"total": 1, "input": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Dropout1(), input_info, expected) verify_model(Dropout2(), input_info, expected) -def test_layernorm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_layernorm(dynamic): """test graph builder for layernorm""" class LayerNorm(Module): @@ -548,21 +625,25 @@ def __init__(self): def forward(self, data): return self.layernorm(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "layer_norm", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "layer_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(LayerNorm(), input_info, expected) -def test_functional_layernorm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_functional_layernorm(dynamic): """test graph builder for functional_layernorm""" class LayerNorm(Module): @@ -576,21 +657,25 @@ def forward(self, data): data, self.weight.shape, self.weight, self.bias, 1e-5 ) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "layer_norm", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "layer_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(LayerNorm((10, 10)), input_info, expected) -def test_cross_entropy(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_cross_entropy(dynamic): """test graph builder for cross_entropy""" class CrossEntropy1(Module): @@ -601,15 +686,6 @@ def __init__(self): def forward(self, logits, targets): return self.loss(logits, targets) - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""}, - ], - "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], - "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, - } - class CrossEntropy2(Module): def __init__(self): super().__init__() @@ -619,15 +695,6 @@ def __init__(self): def forward(self, logits, targets): return self.loss(logits, targets) - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""}, - ], - "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], - "nodes": {"total": 5, "input": 2, "nn.log_softmax": 1, "constant": 1, "nn.nll_loss": 1}, - } - class CrossEntropy3(Module): def __init__(self): super().__init__() @@ -636,42 +703,68 @@ def __init__(self): def forward(self, logits, targets): return self.loss(logits, targets) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, + ], + "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], + "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, + } + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, + ], + "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], + "nodes": {"total": 5, "input": 2, "nn.log_softmax": 1, "constant": 1, "nn.nll_loss": 1}, + } expected3 = { "inputs": [ - {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""}, + {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, ], "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + expected3["prims"] = {"total": 1, "shape": 1} - input_info = [([3, 2], "float32"), ([3], "int32")] + input_info = [([bz, 2], "float32"), ([bz], "int32")] verify_model(CrossEntropy1(), input_info, expected1) verify_model(CrossEntropy2(), input_info, expected2) verify_model(CrossEntropy3(), input_info, expected3) -def test_functional_cross_entropy(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_functional_cross_entropy(dynamic): """test graph builder for functional_cross_entropy""" class CrossEntropy(Module): def forward(self, logits, targets): return torch.nn.functional.cross_entropy(logits, targets) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [3, 10], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""}, + {"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, ], "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([3, 10], "float32"), ([3], "int32")] + input_info = [([bz, 10], "float32"), ([bz], "int32")] verify_model(CrossEntropy(), input_info, expected) -def test_silu(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_silu(dynamic): """test graph builder for silu""" class SiLU(Module): @@ -686,22 +779,26 @@ class SiLU2(Module): def forward(self, data): return torch.nn.functional.silu(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "silu", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "silu", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "nn.silu": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(SiLU(), input_info, expected) verify_model(SiLU2(), input_info, expected) -def test_groupnorm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_groupnorm(dynamic): """test graph builder for groupnorm""" class GroupNorm(Module): @@ -712,21 +809,25 @@ def __init__(self): def forward(self, data): return self.groupnorm(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "group_norm", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "group_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.group_norm": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(GroupNorm(), input_info, expected) -def test_softmax(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_softmax(dynamic): """test graph builder for softmax""" class Softmax(Module): @@ -737,51 +838,62 @@ def __init__(self): def forward(self, data): return self.softmax(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "softmax", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "softmax", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "nn.softmax": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Softmax(), input_info, expected) -def test_binary(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_binary(dynamic): """test graph builder for binary""" - input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] - input_info2 = [([1, 3, 10, 10], "float32")] + bz = "bz" if dynamic else 1 + input_info1 = [([bz, 3, 10, 10], "float32"), ([bz, 3, 10, 10], "float32")] + input_info2 = [([bz, 3, 10, 10], "float32")] # Add class Add1(Module): def forward(self, lhs, rhs): return lhs + rhs + class Add2(Module): + def forward(self, lhs): + return lhs + 1.0 + expected_add1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + ], + "outputs": [ + {"name": "add", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "add", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 3, "input": 2, "add": 1}, } - - class Add2(Module): - def forward(self, lhs): - return lhs + 1.0 - expected_add2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "add", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "add", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 3, "input": 1, "constant": 1, "add": 1}, } + if dynamic: + expected_add1["prims"] = {"total": 1, "shape": 1} + expected_add2["prims"] = {"total": 1, "shape": 1} verify_model(Add1(), input_info1, expected_add1) verify_model(Add2(), input_info2, expected_add2) @@ -791,30 +903,32 @@ class Sub1(Module): def forward(self, lhs, rhs): return lhs - rhs + class Sub2(Module): + def forward(self, lhs): + return lhs - 1.0 + expected_sub1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ - {"name": "subtract", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "subtract", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 2, "subtract": 1}, } - - class Sub2(Module): - def forward(self, lhs): - return lhs - 1.0 - expected_sub2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "subtract", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "subtract", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 1, "constant": 1, "subtract": 1}, } + if dynamic: + expected_sub1["prims"] = {"total": 1, "shape": 1} + expected_sub2["prims"] = {"total": 1, "shape": 1} verify_model(Sub1(), input_info1, expected_sub1) verify_model(Sub2(), input_info2, expected_sub2) @@ -824,30 +938,32 @@ class Mul1(Module): def forward(self, lhs, rhs): return lhs * rhs + class Mul2(Module): + def forward(self, lhs): + return lhs * 1.0 + expected_mul1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ - {"name": "multiply", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "multiply", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 2, "multiply": 1}, } - - class Mul2(Module): - def forward(self, lhs): - return lhs * 1.0 - expected_mul2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "multiply", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "multiply", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 1, "constant": 1, "multiply": 1}, } + if dynamic: + expected_mul1["prims"] = {"total": 1, "shape": 1} + expected_mul2["prims"] = {"total": 1, "shape": 1} verify_model(Mul1(), input_info1, expected_mul1) verify_model(Mul2(), input_info2, expected_mul2) @@ -857,30 +973,32 @@ class TrueDiv1(Module): def forward(self, lhs, rhs): return lhs / rhs + class TrueDiv2(Module): + def forward(self, lhs): + return lhs / 1.0 + expected_div1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ - {"name": "divide", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "divide", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 2, "divide": 1}, } - - class TrueDiv2(Module): - def forward(self, lhs): - return lhs / 1.0 - expected_div2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "divide", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "divide", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 1, "constant": 1, "divide": 1}, } + if dynamic: + expected_div1["prims"] = {"total": 1, "shape": 1} + expected_div2["prims"] = {"total": 1, "shape": 1} verify_model(TrueDiv1(), input_info1, expected_div1) verify_model(TrueDiv2(), input_info2, expected_div2) @@ -890,40 +1008,42 @@ class FloorDiv1(Module): def forward(self, lhs, rhs): return lhs // rhs + class FloorDiv2(Module): + def forward(self, lhs): + return lhs // 1.0 + expected_floordiv1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ { "name": "floor_divide", - "shape": [1, 3, 10, 10], + "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD", } ], "nodes": {"total": 3, "input": 2, "floor_divide": 1}, } - - class FloorDiv2(Module): - def forward(self, lhs): - return lhs // 1.0 - expected_floordiv2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ { "name": "floor_divide", - "shape": [1, 3, 10, 10], + "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD", } ], "nodes": {"total": 3, "input": 1, "constant": 1, "floor_divide": 1}, } + if dynamic: + expected_floordiv1["prims"] = {"total": 1, "shape": 1} + expected_floordiv2["prims"] = {"total": 1, "shape": 1} verify_model(FloorDiv1(), input_info1, expected_floordiv1) verify_model(FloorDiv2(), input_info2, expected_floordiv2) @@ -933,30 +1053,32 @@ class Power1(Module): def forward(self, lhs, rhs): return lhs**rhs + class Power2(Module): + def forward(self, lhs): + return lhs**1.0 + expected_power1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ - {"name": "power", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "power", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 2, "power": 1}, } - - class Power2(Module): - def forward(self, lhs): - return lhs**1.0 - expected_power2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "power", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "power", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 1, "constant": 1, "power": 1}, } + if dynamic: + expected_power1["prims"] = {"total": 1, "shape": 1} + expected_power2["prims"] = {"total": 1, "shape": 1} verify_model(Power1(), input_info1, expected_power1) verify_model(Power2(), input_info2, expected_power2) @@ -966,176 +1088,214 @@ class LT1(Module): def forward(self, lhs, rhs): return lhs < rhs + class LT2(Module): + def forward(self, lhs): + return lhs < 1.0 + expected_lt1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], - "outputs": [{"name": "less", "shape": [1, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], + "outputs": [{"name": "less", "shape": [bz, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], "nodes": {"total": 3, "input": 2, "less": 1}, } - - class LT2(Module): - def forward(self, lhs): - return lhs < 1.0 - expected_lt2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "less", "shape": [1, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], + "outputs": [{"name": "less", "shape": [bz, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], "nodes": {"total": 3, "input": 1, "constant": 1, "less": 1}, } + if dynamic: + expected_lt1["prims"] = {"total": 1, "shape": 1} + expected_lt2["prims"] = {"total": 1, "shape": 1} verify_model(LT1(), input_info1, expected_lt1) verify_model(LT2(), input_info2, expected_lt2) -def test_size(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_size(dynamic): """test graph builder for size""" class Size(Module): def forward(self, data): return data.size() + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], "outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}], "nodes": {"total": 2, "input": 1, "shape": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Size(), input_info, expected) -def test_squeeze(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_squeeze(dynamic): """test graph builder for squeeze""" class Squeeze1(Module): def forward(self, data): return data.squeeze(1) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [3, 1, 4, 1], "dtype": "float32", "layout": "ANBC"}], - "outputs": [{"name": "squeeze", "shape": [3, 4, 1], "dtype": "float32", "layout": "ABC"}], - "nodes": {"total": 2, "input": 1, "squeeze": 1}, - } - class Squeeze2(Module): def forward(self, data): return data.squeeze() - expected2 = { - "inputs": [{"name": "inp_0", "shape": [3, 1, 4, 1], "dtype": "float32", "layout": "ANBC"}], - "outputs": [{"name": "squeeze", "shape": [3, 4], "dtype": "float32", "layout": "AB"}], + bz = "bz" if dynamic else 10 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ADBC"}], + "outputs": [{"name": "squeeze", "shape": [bz, 4, 1], "dtype": "float32", "layout": "ABC"}], "nodes": {"total": 2, "input": 1, "squeeze": 1}, } - - input_info = [([3, 1, 4, 1], "float32")] + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ACBD"} + ], + "outputs": [{"name": "squeeze", "shape": [], "dtype": "float32", "layout": "AB"}], + "nodes": {"total": 2, "input": 1, "squeeze": 1}, + "prims": {"total": 1, "shape": 1}, + } + else: + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ACBD"} + ], + "outputs": [{"name": "squeeze", "shape": [bz, 4], "dtype": "float32", "layout": "AB"}], + "nodes": {"total": 2, "input": 1, "squeeze": 1}, + } + input_info = [([bz, 1, 4, 1], "float32")] verify_model(Squeeze1(), input_info, expected1) verify_model(Squeeze2(), input_info, expected2) -def test_unsqueeze(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_unsqueeze(dynamic): """test graph builder for unsqueeze""" class Unsqueeze1(Module): def forward(self, data): return data.unsqueeze(1) + class Unsqueeze2(Module): + def forward(self, data): + return data.unsqueeze(-1) + + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ACDE"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ACDE"} ], "outputs": [ { "name": "expand_dims", - "shape": [1, 1, 3, 10, 10], + "shape": [bz, 1, 3, 10, 10], "dtype": "float32", "layout": "ABCDE", } ], "nodes": {"total": 2, "input": 1, "expand_dims": 1}, } - - class Unsqueeze2(Module): - def forward(self, data): - return data.unsqueeze(-1) - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCE"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCE"} ], "outputs": [ { "name": "expand_dims", - "shape": [1, 3, 10, 10, 1], + "shape": [bz, 3, 10, 10, 1], "dtype": "float32", "layout": "ABCDE", } ], "nodes": {"total": 2, "input": 1, "expand_dims": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Unsqueeze1(), input_info, expected1) verify_model(Unsqueeze2(), input_info, expected2) -def test_getattr(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_getattr(dynamic): """test graph builder for getattr""" class GetAttr1(Module): def forward(self, data): return data.shape + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], "outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}], "nodes": {"total": 2, "input": 1, "shape": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(GetAttr1(), input_info, expected) -def test_getitem(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_getitem(dynamic): """test graph builder for getitem""" class Slice1(Module): def forward(self, x): return x[0, 1::2, :, :3] + class Slice2(Module): + def forward(self, x): + return x[:, None, None, :, None] + + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "reshape", "shape": [1, 1, 10, 3], "dtype": "float32", "layout": "ABCD"} + { + "name": "reshape", + "shape": ["MIN_2" if dynamic else 1, 1, 10, 3], + "dtype": "float32", + "layout": "ABCD", + } ], "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1}, } - - class Slice2(Module): - def forward(self, x): - return x[:, None, None, :, None] - expected2 = { - "inputs": [{"name": "inp_0", "shape": [8, 16], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 16], "dtype": "float32", "layout": "AB"}], "outputs": [ - {"name": "reshape", "shape": [8, 1, 1, 16, 1], "dtype": "float32", "layout": "ANCHB"} + {"name": "reshape", "shape": [bz, 1, 1, 16, 1], "dtype": "float32", "layout": "CDAEB"} ], "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1}, } + if dynamic: + expected1["prims"] = {"total": 3, "shape": 1, "Int": 1, "Min": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(Slice1(), [([1, 3, 10, 10], "float32")], expected1) - verify_model(Slice2(), [([8, 16], "float32")], expected2) + verify_model(Slice1(), [([bz, 3, 10, 10], "float32")], expected1) + verify_model(Slice2(), [([bz, 16], "float32")], expected2) -def test_unary(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_unary(dynamic): """test graph builder for unary""" - input_info = [([1, 3, 10, 10], "float32")] + bz = "bz" if dynamic else 1 + input_info = [([bz, 3, 10, 10], "float32")] # sin class Sin(Module): @@ -1144,11 +1304,15 @@ def forward(self, data): expected_sin = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "sin", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "sin", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 2, "input": 1, "sin": 1}, } + if dynamic: + expected_sin["prims"] = {"total": 1, "shape": 1} verify_model(Sin(), input_info, expected_sin) @@ -1159,11 +1323,15 @@ def forward(self, data): expected_cos = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "cos", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "cos", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 2, "input": 1, "cos": 1}, } + if dynamic: + expected_cos["prims"] = {"total": 1, "shape": 1} verify_model(Cos(), input_info, expected_cos) @@ -1174,11 +1342,15 @@ def forward(self, data): expected_exp = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "exp", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "exp", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 2, "input": 1, "exp": 1}, } + if dynamic: + expected_exp["prims"] = {"total": 1, "shape": 1} verify_model(Exp(), input_info, expected_exp) @@ -1189,13 +1361,15 @@ def forward(self, data): expected_sqrt = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "sqrt", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "sqrt", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "sqrt": 1}, } + if dynamic: + expected_sqrt["prims"] = {"total": 1, "shape": 1} verify_model(Sqrt(), input_info, expected_sqrt) @@ -1206,13 +1380,15 @@ def forward(self, data): expected_sigmoid = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "sigmoid", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "sigmoid", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "sigmoid": 1}, } + if dynamic: + expected_sigmoid["prims"] = {"total": 1, "shape": 1} verify_model(Sigmoid(), input_info, expected_sigmoid) @@ -1223,123 +1399,144 @@ def forward(self, data): expected_round = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "round", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "round", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "round": 1}, } + if dynamic: + expected_round["prims"] = {"total": 1, "shape": 1} verify_model(Round(), input_info, expected_round) -def test_gelu(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_gelu(dynamic): """test graph builder for gelu""" class Gelu(Module): def forward(self, data): return torch.nn.functional.gelu(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "gelu", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "gelu", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "nn.gelu": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Gelu(), input_info, expected) -def test_tanh(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_tanh(dynamic): """test graph builder for tanh""" class Tanh(Module): def forward(self, data): return torch.tanh(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "tanh", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "tanh", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "tanh": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Tanh(), input_info, expected) -def test_clamp(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_clamp(dynamic): """test graph builder for clamp""" class Clamp(Module): def forward(self, data): return torch.clamp(data, min=0.1, max=0.5) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "clip", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "clip", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "clip": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Clamp(), input_info, expected) -def test_interpolate(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_interpolate(dynamic): """test graph builder for interpolate""" class Interpolate(Module): def forward(self, data): return torch.nn.functional.interpolate(data, (5, 5)) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "resize2d", "shape": [1, 3, 5, 5], "dtype": "float32", "layout": "NCHW"} + {"name": "resize2d", "shape": [bz, 3, 5, 5], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "image.resize2d": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Interpolate(), input_info, expected) -def test_addmm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_addmm(dynamic): """test graph builder for addmm""" class Addmm(Module): def forward(self, x_1, x_2, x_3): return torch.addmm(x_1, x_2, x_3) + mdim = "mdim" if dynamic else 10 + ndim = "ndim" if dynamic else 20 + kdim = "kdim" if dynamic else 30 expected = { "inputs": [ - {"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": "NC"}, - {"name": "inp_1", "shape": [10, 10], "dtype": "float32", "layout": "NC"}, - {"name": "inp_2", "shape": [10, 10], "dtype": "float32", "layout": "IO"}, + {"name": "inp_0", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}, + {"name": "inp_1", "shape": [mdim, kdim], "dtype": "float32", "layout": "NC"}, + {"name": "inp_2", "shape": [kdim, ndim], "dtype": "float32", "layout": "IO"}, ], - "outputs": [{"name": "add", "shape": [10, 10], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "add", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}], "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1}, } + if dynamic: + expected["prims"] = {"total": 3, "shape": 3} - input_info = [ - ([10, 10], "float32"), - ([10, 10], "float32"), - ([10, 10], "float32"), - ] + input_info = [([mdim, ndim], "float32"), ([mdim, kdim], "float32"), ([kdim, ndim], "float32")] verify_model(Addmm(), input_info, expected) -def test_split(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_split(dynamic): """test graph builder for split""" class Split1(Module): @@ -1350,98 +1547,114 @@ class Split2(Module): def forward(self, data): return torch.split(data, [1, 2], dim=1) + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_2", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_1", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_2", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "nodes": {"total": 2, "input": 1, "split": 1}, } - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [1, 2, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_1", "shape": [bz, 2, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "nodes": {"total": 2, "input": 1, "split": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Split1(), input_info, expected1) verify_model(Split2(), input_info, expected2) -def test_unbind(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_unbind(dynamic): """test graph builder for unbind""" class Unbind(Module): def forward(self, data): return torch.unbind(data, dim=1) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "tuple_0", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, - {"name": "tuple_1", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, - {"name": "tuple_2", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, + {"name": "tuple_0", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, + {"name": "tuple_1", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, + {"name": "tuple_2", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, ], "nodes": {"total": 9, "input": 1, "split": 1, "get_item": 3, "squeeze": 3, "tuple": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Unbind(), input_info, expected) -def test_cumsum(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_cumsum(dynamic): """test graph builder for cumsum""" class Cumsum(Module): def forward(self, data): return torch.cumsum(data, dim=1, dtype=torch.int32) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "cumsum", "shape": [1, 2, 3, 4], "dtype": "int32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "cumsum", "shape": [bz, 2, 3, 4], "dtype": "int32", "layout": ""}], "nodes": {"total": 2, "input": 1, "cumsum": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(Cumsum(), input_info, expected) -def test_chunk(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_chunk(dynamic): """test graph builder for chunk""" class Chunk(Module): def forward(self, data): return torch.chunk(data, 3, dim=1) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_2", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_1", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_2", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "nodes": {"total": 2, "input": 1, "split": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Chunk(), input_info, expected) -def test_inplace_fill(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_inplace_fill(dynamic): """test graph builder for inplace_fill""" class InplaceFill(Module): @@ -1449,13 +1662,21 @@ def forward(self, data): data.fill_(1.5) return data - expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "const", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "constant": 1}, - } - - verify_model(InplaceFill(), [([10, 10], "float32")], expected) + bz = "bz" if dynamic else 1 + if dynamic: + expected = { + "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "full", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1}, + "prims": {"total": 1, "shape": 1}, + } + else: + expected = { + "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "const", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "nodes": {"total": 2, "input": 1, "constant": 1}, + } + verify_model(InplaceFill(), [([bz, 10], "float32")], expected) def test_arange(): @@ -1517,7 +1738,8 @@ def forward(self): verify_model(Empty2(), [([10, 10], "float32")], expected2) -def test_tril(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_tril(dynamic): """test graph builder for tril""" class Tril(Module): @@ -1529,18 +1751,23 @@ def forward(self, data): data.tril_(1) return data + row = "row" if dynamic else 10 + col = "col" if dynamic else 10 expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "tril", "shape": [10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [row, col], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "tril", "shape": [row, col], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "tril": 1}, } + if dynamic: + expected["prims"] = {"total": 2, "shape": 2} - input_info = [([10, 10], "float32")] + input_info = [([row, col], "float32")] verify_model(Tril(), input_info, expected) verify_model(InplaceTril(), input_info, expected) -def test_triu(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_triu(dynamic): """test graph builder for triu""" class Triu(Module): @@ -1552,13 +1779,17 @@ def forward(self, data): data.triu_(1) return data + row = "row" if dynamic else 10 + col = "col" if dynamic else 10 expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "triu", "shape": [10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [row, col], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "triu", "shape": [row, col], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "triu": 1}, } + if dynamic: + expected["prims"] = {"total": 2, "shape": 2} - input_info = [([10, 10], "float32")] + input_info = [([row, col], "float32")] verify_model(Triu(), input_info, expected) verify_model(InplaceTriu(), input_info, expected) @@ -1580,7 +1811,8 @@ def forward(self, x): verify_model(NewOnes(), input_info, expected) -def test_expand(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_expand(dynamic): """test graph builder for expand""" class Expand1(Module): @@ -1591,20 +1823,24 @@ class Expand2(Module): def forward(self, x): return x.expand(4, -1, -1, 4) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], "outputs": [ {"name": "broadcast_to", "shape": [4, 2, 3, 4], "dtype": "float32", "layout": ""} ], "nodes": {"total": 2, "input": 1, "broadcast_to": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(Expand1(), input_info, expected) verify_model(Expand2(), input_info, expected) -def test_reduce(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_reduce(dynamic): """test graph builder for reduce""" # sum @@ -1612,20 +1848,25 @@ class Sum(Module): def forward(self, x): return torch.sum(x, (2, 1)) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ANCB"}], - "outputs": [{"name": "sum", "shape": [1, 4], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ACDB"}], + "outputs": [{"name": "sum", "shape": [bz, 4], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "sum": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(Sum(), input_info, expected) -def test_datatype(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_datatype(dynamic): """test graph builder for datatype""" - input_info = [([1, 2, 3, 4], "float32")] + bz = "bz" if dynamic else 1 + input_info = [([bz, 2, 3, 4], "float32")] # float class ToFloat(Module): @@ -1633,12 +1874,14 @@ def forward(self, x): return x.float() expected1 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} verify_model(ToFloat(), input_info, expected1) @@ -1648,12 +1891,14 @@ def forward(self, x): return x.half() expected2 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float16", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float16", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected2["prims"] = {"total": 1, "shape": 1} verify_model(ToHalf(), input_info, expected2) @@ -1663,12 +1908,14 @@ def forward(self, x): return x.type(torch.float32) expected3 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected3["prims"] = {"total": 1, "shape": 1} # type class TypeFromAttr(Module): @@ -1676,12 +1923,14 @@ def forward(self, x): return x.type(x.getattr("dtype")) expected4 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected4["prims"] = {"total": 1, "shape": 1} # astype class AsType(Module): @@ -1689,91 +1938,140 @@ def forward(self, x): return x.astype(torch.float32) expected5 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected5["prims"] = {"total": 1, "shape": 1} verify_model(Type(), input_info, expected3) verify_model(TypeFromAttr(), input_info, expected4) verify_model(AsType(), input_info, expected5) -def test_permute(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_permute(dynamic): """test graph builder for permute""" class Permute(Module): def forward(self, x): return x.permute(0, 3, 2, 1) + bz = "bz" if dynamic else 1 + channel = "channel" if dynamic else 2 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ADCB"}], + "inputs": [ + {"name": "inp_0", "shape": [bz, channel, 3, 4], "dtype": "float32", "layout": "ADCB"} + ], "outputs": [ - {"name": "permute_dims", "shape": [1, 4, 3, 2], "dtype": "float32", "layout": "ABCD"} + { + "name": "permute_dims", + "shape": [bz, 4, 3, channel], + "dtype": "float32", + "layout": "ABCD", + } ], "nodes": {"total": 2, "input": 1, "permute_dims": 1}, } + if dynamic: + expected["prims"] = {"total": 2, "shape": 2} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, channel, 3, 4], "float32")] verify_model(Permute(), input_info, expected) -def test_reshape(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_reshape(dynamic): """test graph builder for reshape""" class Reshape(Module): def forward(self, x): - return x.reshape(2, 12) + return x.reshape(-1, 12) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "reshape", "shape": [2, 12], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], + "outputs": [ + { + "name": "reshape", + "shape": ["MUL_2" if dynamic else 2, 12], + "dtype": "float32", + "layout": "", + } + ], "nodes": {"total": 2, "input": 1, "reshape": 1}, } + if dynamic: + expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(Reshape(), input_info, expected) -def test_transpose(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_transpose(dynamic): """test graph builder for transpose""" class Transpose(Module): def forward(self, x): return x.transpose(1, 3) + bz = "bz" if dynamic else 1 + hidden = "hidden" if dynamic else 4 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ADCB"}], + "inputs": [ + {"name": "inp_0", "shape": [bz, 2, 3, hidden], "dtype": "float32", "layout": "ADCB"} + ], "outputs": [ - {"name": "permute_dims", "shape": [1, 4, 3, 2], "dtype": "float32", "layout": "ABCD"} + { + "name": "permute_dims", + "shape": [bz, hidden, 3, 2], + "dtype": "float32", + "layout": "ABCD", + } ], "nodes": {"total": 2, "input": 1, "permute_dims": 1}, } + if dynamic: + expected["prims"] = {"total": 2, "shape": 2} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, hidden], "float32")] verify_model(Transpose(), input_info, expected) -def test_view(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_view(dynamic): """test graph builder for view""" class View(Module): def forward(self, x): - return x.view(2, 12) + return x.view(-1, 12) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "reshape", "shape": [2, 12], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], + "outputs": [ + { + "name": "reshape", + "shape": ["MUL_2" if dynamic else 2, 12], + "dtype": "float32", + "layout": "", + } + ], "nodes": {"total": 2, "input": 1, "reshape": 1}, } + if dynamic: + expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(View(), input_info, expected) -def test_keep_params(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_keep_params(dynamic): """test graph builder for keep_params""" class Conv2D1(Module): @@ -1784,228 +2082,271 @@ def __init__(self): def forward(self, data): return self.conv(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "conv2d", - "shape": [1, 6, 4, 4], + "shape": [bz, 6, 4, 4], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")], expected) + verify_model(Conv2D1(), [([bz, 3, 10, 10], "float32")], expected) -def test_unwrap_unit_return_tuple(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_unwrap_unit_return_tuple(dynamic): """test graph builder for unwrap_unit_return_tuple""" class Identity(Module): def forward(self, x): return (x,) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "tuple", "shape": [256, 256], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "tuple", "shape": [bz, 256], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "tuple": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Identity(), [([256, 256], "float32")], expected) + verify_model(Identity(), [([bz, 256], "float32")], expected) -def test_no_bind_return_tuple(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_no_bind_return_tuple(dynamic): """test graph builder for no_bind_return_tuple""" class Identity(Module): def forward(self, x, y): return (x, y) + bz_x = "bz" if dynamic else 1 + bz_y = "bz" if dynamic else 2 expected = { "inputs": [ - {"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [256, 256], "dtype": "float32", "layout": ""}, + {"name": "inp_0", "shape": [bz_x, 256], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz_y, 256], "dtype": "float32", "layout": ""}, ], "outputs": [ - {"name": "tuple_0", "shape": [256, 256], "dtype": "float32", "layout": ""}, - {"name": "tuple_1", "shape": [256, 256], "dtype": "float32", "layout": ""}, + {"name": "tuple_0", "shape": [bz_x, 256], "dtype": "float32", "layout": ""}, + {"name": "tuple_1", "shape": [bz_y, 256], "dtype": "float32", "layout": ""}, ], "nodes": {"total": 3, "input": 2, "tuple": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([256, 256], "float32"), ([256, 256], "float32")] + input_info = [([bz_x, 256], "float32"), ([bz_y, 256], "float32")] verify_model(Identity(), input_info, expected) -def test_argmax(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_argmax(dynamic): """test graph builder for argmax""" class Argmax1(Module): def forward(self, data): return torch.argmax(data, dim=-1) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmax", "shape": [256], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "argmax": 1}, - } - class Argmax2(Module): def forward(self, data): return torch.argmax(data, dim=-1, keepdim=True) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "argmax", "shape": [bz], "dtype": "int64", "layout": ""}], + "nodes": {"total": 2, "input": 1, "argmax": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmax", "shape": [256, 1], "dtype": "int64", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "argmax", "shape": [bz, 1], "dtype": "int64", "layout": ""}], "nodes": {"total": 2, "input": 1, "argmax": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(Argmax1(), [([256, 256], "float32")], expected1) - verify_model(Argmax2(), [([256, 256], "float32")], expected2) + verify_model(Argmax1(), [([bz, 256], "float32")], expected1) + verify_model(Argmax2(), [([bz, 256], "float32")], expected2) -def test_argmin(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_argmin(dynamic): """test graph builder for argmin""" class Argmin1(Module): def forward(self, data): return torch.argmin(data) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmin", "shape": [], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "argmin": 1}, - } - class Argmin2(Module): def forward(self, data): return torch.argmin(data, keepdim=True) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "argmin", "shape": [], "dtype": "int64", "layout": ""}], + "nodes": {"total": 2, "input": 1, "argmin": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], "outputs": [{"name": "argmin", "shape": [1, 1], "dtype": "int64", "layout": ""}], "nodes": {"total": 2, "input": 1, "argmin": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(Argmin1(), [([256, 256], "float32")], expected1) - verify_model(Argmin2(), [([256, 256], "float32")], expected2) + verify_model(Argmin1(), [([bz, 256], "float32")], expected1) + verify_model(Argmin2(), [([bz, 256], "float32")], expected2) -def test_to(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_to(dynamic): """test graph builder for to""" class To1(Module): def forward(self, data): return data.to(torch.float16) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "astype", "shape": [256, 256], "dtype": "float16", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "astype": 1}, - } - class To2(Module): def forward(self, data): return data.to("cpu") + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "astype", "shape": [bz, 256], "dtype": "float16", "layout": "AB"}], + "nodes": {"total": 2, "input": 1, "astype": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], "nodes": {"total": 1, "input": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(To1(), [([256, 256], "float32")], expected1) - verify_model(To2(), [([256, 256], "float32")], expected2) + verify_model(To1(), [([bz, 256], "float32")], expected1) + verify_model(To2(), [([bz, 256], "float32")], expected2) -def test_mean(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_mean(dynamic): """test graph builder for mean""" class Mean(Module): def forward(self, data): return data.mean(-1) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AN"}], - "outputs": [{"name": "mean", "shape": [256], "dtype": "float32", "layout": "A"}], - "nodes": {"total": 2, "input": 1, "mean": 1}, - } - class MeanKeepDim(Module): def forward(self, data): return data.mean(-1, keepdim=True) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "mean", "shape": [bz], "dtype": "float32", "layout": "A"}], + "nodes": {"total": 2, "input": 1, "mean": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "mean", "shape": [256, 1], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "mean", "shape": [bz, 1], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "mean": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(Mean(), [([256, 256], "float32")], expected1) - verify_model(MeanKeepDim(), [([256, 256], "float32")], expected2) + verify_model(Mean(), [([bz, 256], "float32")], expected1) + verify_model(MeanKeepDim(), [([bz, 256], "float32")], expected2) -def test_rsqrt(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_rsqrt(dynamic): """test graph builder for rsqrt""" class Rsqrt(Module): def forward(self, data): return torch.rsqrt(data) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "rsqrt", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "rsqrt", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "rsqrt": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Rsqrt(), [([256, 256], "float32")], expected) + verify_model(Rsqrt(), [([bz, 256], "float32")], expected) -def test_neg(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_neg(dynamic): """test graph builder for neg""" class Neg(Module): def forward(self, data): return -data + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "negative", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "negative", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "negative": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Neg(), [([256, 256], "float32")], expected) + verify_model(Neg(), [([bz, 256], "float32")], expected) -def test_max(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_max(dynamic): """test graph builder for max""" class Max(Module): def forward(self, x, y): return torch.max(x, y) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}, - {"name": "inp_1", "shape": [256, 256], "dtype": "float32", "layout": "AB"}, + {"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}, + {"name": "inp_1", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}, ], - "outputs": [{"name": "maximum", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "maximum", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 3, "input": 2, "maximum": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], expected) + verify_model(Max(), [([bz, 256], "float32"), ([bz, 256], "float32")], expected) -def test_attention(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_attention(dynamic): """test graph builder for attention""" # pylint: disable=import-outside-toplevel import torch.nn.functional as F + seq = "seq" if dynamic else 128 + class Attention1(Module): def forward(self, q_data, k_data, v_data): return F.scaled_dot_product_attention(q_data, k_data, v_data) @@ -2016,25 +2357,27 @@ def forward(self, q_data, k_data, v_data): expected1 = { "inputs": [ - {"name": "inp_0", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_1", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_2", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_0", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_1", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, ], "outputs": [ { "name": "attention", - "shape": [32, 128, 8, 64], + "shape": [1, seq, 8, 64], "dtype": "float32", "layout": "ABCD", } ], "nodes": {"total": 4, "input": 3, "msc.attention": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} input_info = [ - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, 64], "float32"), ] verify_model(Attention1(), input_info, expected1) verify_model(Attention2(), input_info, expected1) @@ -2045,28 +2388,31 @@ def forward(self, q_data, k_data, v_data, mask): expected2 = { "inputs": [ - {"name": "inp_0", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_1", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_2", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_3", "shape": [32, 8, 128, 128], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_1", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_3", "shape": [1, 8, seq, seq], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ { "name": "attention_bias", - "shape": [32, 128, 8, 64], + "shape": [1, seq, 8, 64], "dtype": "float32", "layout": "ABCD", } ], "nodes": {"total": 5, "input": 4, "msc.attention": 1}, } + if dynamic: + expected2["prims"] = {"total": 1, "shape": 1} + verify_model( Attention3(), [ - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 128], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, seq], "float32"), ], expected2, ) diff --git a/tests/python/contrib/test_msc/test_pipeline.py b/tests/python/contrib/test_msc/test_pipeline.py index 149041959416..ddc70243887b 100644 --- a/tests/python/contrib/test_msc/test_pipeline.py +++ b/tests/python/contrib/test_msc/test_pipeline.py @@ -37,7 +37,7 @@ def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1 path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static") return { - "workspace": msc_utils.msc_dir(path), + "workspace": msc_utils.msc_dir(path, keep_history=False), "verbose": "critical", "model_type": model_type, "inputs": inputs, @@ -161,7 +161,7 @@ def test_tvm_pipeline(dynamic): "inputs": [ {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], "nodes": { "total": 229, "input": 1, @@ -217,7 +217,7 @@ def test_torch_pipeline(dynamic): "inputs": [ {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], "nodes": { "total": 229, "input": 1, diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py index 55fc9dd43e4f..031572a98e4a 100644 --- a/tests/python/contrib/test_msc/test_runner.py +++ b/tests/python/contrib/test_msc/test_runner.py @@ -84,13 +84,15 @@ def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1): torch_model = _get_torch_model("resnet50", training) if torch_model: path = "test_runner_torch_{}_{}".format(runner_cls.__name__, device) - workspace = msc_utils.set_workspace(msc_utils.msc_dir(path)) + workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False)) log_path = workspace.relpath("MSC_LOG", keep_history=False) msc_utils.set_global_logger("critical", log_path) input_info = [([1, 3, 224, 224], "float32")] datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info] torch_datas = [torch.from_numpy(d) for d in datas] graph_model = fx.symbolic_trace(torch_model) + if training: + input_info = [([tvm.tir.Var("bz", "int64"), 3, 224, 224], "float32")] with torch.no_grad(): golden = torch_model(*torch_datas) mod = from_fx(graph_model, input_info) @@ -103,34 +105,34 @@ def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1): tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol) -def test_tvm_runner_cpu(): +@pytest.mark.parametrize("training", [True, False]) +def test_tvm_runner_cpu(training): """Test runner for tvm on cpu""" - for training in [True, False]: - _test_from_torch(TVMRunner, "cpu", training=training) + _test_from_torch(TVMRunner, "cpu", training=training) @tvm.testing.requires_cuda -def test_tvm_runner_cuda(): +@pytest.mark.parametrize("training", [True, False]) +def test_tvm_runner_cuda(training): """Test runner for tvm on cuda""" - for training in [True, False]: - _test_from_torch(TVMRunner, "cuda", training=training) + _test_from_torch(TVMRunner, "cuda", training=training) -def test_torch_runner_cpu(): +@pytest.mark.parametrize("training", [True, False]) +def test_torch_runner_cpu(training): """Test runner for torch on cpu""" - for training in [True, False]: - _test_from_torch(TorchRunner, "cpu", training=training) + _test_from_torch(TorchRunner, "cpu", training=training) @tvm.testing.requires_cuda -def test_torch_runner_cuda(): +@pytest.mark.parametrize("training", [True, False]) +def test_torch_runner_cuda(training): """Test runner for torch on cuda""" - for training in [True, False]: - _test_from_torch(TorchRunner, "cuda", training=training, atol=1e-1, rtol=1e-1) + _test_from_torch(TorchRunner, "cuda", training=training, atol=1e-1, rtol=1e-1) @requires_tensorrt @@ -146,7 +148,7 @@ def test_tensorflow_runner(): tf_graph, graph_def = _get_tf_graph() if tf_graph and graph_def: path = "test_runner_tf" - workspace = msc_utils.set_workspace(msc_utils.msc_dir(path)) + workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False)) log_path = workspace.relpath("MSC_LOG", keep_history=False) msc_utils.set_global_logger("critical", log_path) data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32") diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index 22354bb2c131..ac6f2d6c6f74 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -47,7 +47,7 @@ def _get_config( path = "_".join(["test_tools", model_type, compile_type] + [t["tool_type"] for t in tools]) return { - "workspace": msc_utils.msc_dir(path), + "workspace": msc_utils.msc_dir(path, keep_history=False), "verbose": "critical", "model_type": model_type, "inputs": inputs, @@ -229,7 +229,7 @@ def get_model_info(compile_type): "inputs": [ {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], "nodes": { "total": 229, "input": 1, From 4ab3f82669fb20d77cae47704c857ab39a577417 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 16 Sep 2024 23:13:41 +0900 Subject: [PATCH 146/202] [Relax][PyTorch] Cleanup Tensor Manipulation and Creation op converters (#17376) * cleanup `_cat()` * cleanup `_cumsum()` * cleanup `_expand()` * cleanup `_flatten()` * cleanup `_permute()` * cleanup `_repeat()` * cleanup `_reshape()` * cleanup `_size()` * cleanup `_split()` * cleanup `_squeeze()` * cleanup `_tile()` * cleanup `_transpose()` * cleanup `chunk()` * cleanup `_arange()` * cleanup `_empty()` * cleanup `_inplace_fill()` * cleanup `_full()` * cleanup `_index_select()` * cleanup `_inplace_masked_fill()` * cleanup `_masked_fill()` * cleanup `_new_ones()` * cleanup `_ones()` * cleanup `_tensor()` * `_inplace_tril_triu()` is an unary op * `_batch_norm_2d()` is a nn ops * `_interpolate()` is a nn ops * `_cross_entropy()` is a nn ops * chore * fix tensor size --- .../tvm/relax/frontend/torch/fx_translator.py | 755 +++++++++--------- 1 file changed, 358 insertions(+), 397 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 4dc49d20ff36..983bce0255d9 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -212,6 +212,20 @@ def _softmax_module(self, node: fx.Node) -> relax.Var: assert dim is not None return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + def _inplace_tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + + mutated = self.block_builder.emit(op(x, k)) + self.env[node.args[0]] = mutated + return mutated + + return convert + def _tril_triu(self, op: Callable) -> Callable: from torch import fx @@ -356,6 +370,29 @@ def _baddbmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res + def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params[module.bias] + running_mean = self._convert_torch_tensor_to_relax(module.running_mean) + running_var = self._convert_torch_tensor_to_relax(module.running_var) + eps = module.eps + + res_tuple = self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + ) + ) + + return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + def _conv1d_transpose_impl( self, x: relax.Expr, @@ -683,6 +720,40 @@ def _conv3d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) + def _cross_entropy(self, node: fx.Node) -> relax.Expr: + preds = self.env[node.args[0]] + targets = self.env[node.args[1]] + weights = self.env.get(node.kwargs["weight"], None) + reduction = node.kwargs["reduction"] + ignore_index = node.kwargs["ignore_index"] + + return self.block_builder.emit( + relax.op.nn.nll_loss( + relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index + ) + ) + + def _cross_entropy_module(self, node: fx.Node) -> relax.Expr: + preds = self.env[node.args[0]] + targets = self.env[node.args[1]] + module = self.named_modules[node.target] + + weights = module.weight + if weights is not None: + if weights in self.params: + weights = self.params[weights] + else: + weights = relax.const(weights.numpy(), preds.struct_info.dtype) + + reduction = module.reduction + ignore_index = module.ignore_index + + return self.block_builder.emit( + relax.op.nn.nll_loss( + relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index + ) + ) + def _einsum(self, node: fx.Node) -> relax.Var: import torch # type: ignore @@ -740,6 +811,80 @@ def _group_norm_module(self, node: fx.Node) -> relax.Var: ) ) + def _interpolate(self, node: fx.Node) -> relax.Var: + # torch.nn.functional.interpolate( + # input, size=None, scale_factor=None, mode='nearest', align_corners=None, + # recompute_scale_factor=None, antialias=False) + # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout + # it basically replicates the implementation in tvm.relay.frontend.pytorch + data = self.env[node.args[0]] + size = ( + node.args[1] + if len(node.args) > 1 + else (node.kwargs["size"] if "size" in node.kwargs else None) + ) + scale_factor = ( + node.args[2] + if len(node.args) > 2 + else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs else None) + ) + method = ( + node.args[3] + if len(node.args) > 3 + else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest") + ) + align_corners = ( + node.args[4] + if len(node.args) > 4 + else (node.kwargs["align_corners"] if "align_corners" in node.kwargs else None) + ) + recompute_scale_factor = ( + node.args[5] + if len(node.args) > 5 + else ( + node.kwargs["recompute_scale_factor"] + if "recompute_scale_factor" in node.kwargs + else None + ) + ) + antialias = ( + node.args[6] + if len(node.args) > 6 + else (node.kwargs["antialias"] if "antialias" in node.kwargs else False) + ) + + assert recompute_scale_factor is None + assert antialias is False + + if size is None: + shape = self.shape_of(data) + assert isinstance(shape, relax.ShapeExpr) + if isinstance(scale_factor, tuple): + assert len(scale_factor) == len(shape) - 2 + size = tuple( + int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape)) + ) + else: + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + if method.startswith("nearest"): + method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] + + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" + + return self.block_builder.emit( + relax.op.image.resize2d( + data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: from torch.fx.immutable_collections import immutable_list import numpy as np # type: ignore @@ -913,230 +1058,106 @@ def convert(node: fx.Node): return convert - ########## DataType ########## - - def _float(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) - - def _half(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + ########## Manipulation ########## - def _to(self, node: fx.Node) -> relax.Var: - import torch + def _cat(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - if len(node.args) == 2: - if isinstance(node.args[1], torch.dtype): - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - elif "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - return x + chunks = node.args[1] + dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) + return self.block_builder.emit(relax.op.split(x, chunks, dim)) - def _type(self, node: fx.Node) -> relax.Var: + def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - ########## Creation ########## + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + if "dtype" in node.kwargs: + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + else: + dtype = None + if "out" in node.kwargs: + raise ValueError("specifying out for cumsum is not supported yet") - def _arange(self, node: fx.Node) -> relax.Var: - import torch + return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) - start_end_step = [None, None, None] - if "start" in node.kwargs: - start_end_step[0] = node.kwargs["start"] - if "end" in node.kwargs: - start_end_step[1] = node.kwargs["end"] - if "step" in node.kwargs: - start_end_step[2] = node.kwargs["step"] + def _expand(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + sizes = args[1:] if len(args) > 2 else args[1] + broadcast_shape, in_shape = [], self.shape_of(args[0]) + for idx, i in enumerate(sizes): + if isinstance(i, int) and i == -1: + broadcast_shape.append(in_shape[idx]) + else: + broadcast_shape.append(i) + return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) - if len(node.args) == 1: - assert start_end_step[1] is None - start_end_step[1] = node.args[0] - elif len(node.args) == 2: - assert start_end_step[0] is None - assert start_end_step[1] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - elif len(node.args) == 3: - assert start_end_step[0] is None - assert start_end_step[1] is None - assert start_end_step[2] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - start_end_step[2] = node.args[2] + def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var: + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) - if start_end_step[0] is None: - start_end_step[0] = 0 - if start_end_step[2] is None: - start_end_step[2] = 1 + def _flatten(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0) + end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1) + return self._flatten_impl(x, start_dim, end_dim) - if "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - elif any([isinstance(x, float) for x in start_end_step]): - dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype()) - else: - dtype = "int64" - start_end_step = [ - self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step - ] - return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) + def _flatten_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + start_dim = module.start_dim + end_dim = module.end_dim + return self._flatten_impl(x, start_dim, end_dim) - def _empty(self, node: fx.Node) -> relax.Var: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - return self.block_builder.emit(relax.op.zeros(node.args, dtype)) + def _permute(self, node: fx.Node) -> relax.Var: + import torch # type: ignore - def _inplace_fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] - dtype = x.struct_info.dtype - value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) - filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) - self.env[node.args[0]] = filled - return filled + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.permute_dims(x, dims)) - def _tensor(self, node: fx.Node) -> relax.Var: - dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None - if isinstance(node.args[0], float): - return relax.const(node.args[0], dtype if dtype is not None else "float32") - elif isinstance(node.args[0], int): - return relax.const(node.args[0], dtype if dtype is not None else "int64") - raise ValueError("torch.tensor with value not a float or int is not accepted") + def _repeat(self, node: fx.Node) -> relax.Var: + import torch # type: ignore - def _inplace_tril_triu(self, op: Callable) -> Callable: - from torch import fx + args = self.retrieve_args(node) + x = args[0] + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.tile(x, dims)) - def convert(node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - k = node.args[1] if len(node.args) > 1 else 0 - assert isinstance(k, int) - - mutated = self.block_builder.emit(op(x, k)) - self.env[node.args[0]] = mutated - return mutated - - return convert - - def _new_ones(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - self_var = args[0] - size = args[1:] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, self_var.struct_info.dtype), - self_var.struct_info.dtype, - ) - ) - - def _ones(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = args[0] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - dtype = ( - TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - if "dtype" in node.kwargs - else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) - ) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, dtype), - dtype, - ) - ) - - def _full(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = args[0] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - dtype = ( - TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - if "dtype" in node.kwargs - else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) - ) - value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) - return self.block_builder.emit( - relax.op.full( - size, - value, - dtype, - ) - ) - - ########## Manipulation ########## - - def _cat(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + def _reshape(self, node: fx.Node) -> relax.Var: + import torch # type: ignore - def _expand(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - broadcast_shape, in_shape = [], self.shape_of(args[0]) - for idx, i in enumerate(args[1:]): - if isinstance(i, int) and i == -1: - broadcast_shape.append(in_shape[idx]) - else: - broadcast_shape.append(i) - return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) + x = args[0] + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.reshape(x, dims)) - def _flatten(self, node: fx.Node) -> relax.Var: + def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - start_dim = module.start_dim - end_dim = module.end_dim - else: - start_dim = node.args[1] if len(node.args) >= 2 else 0 - end_dim = node.args[2] if len(node.args) == 3 else -1 shape = self.shape_of(x) - start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim - end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim - flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) - new_shape = ( - [shape[i] for i in range(0, start_dim)] - + [flattened] - + [shape[i] for i in range(end_dim + 1, len(shape))] - ) - return self.block_builder.emit(relax.op.reshape(x, new_shape)) - - def _permute(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) - - def _reshape(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) + if len(node.args) == 1: + assert isinstance(shape, relax.ShapeExpr) + return shape + assert len(node.args) == 2 + idx = node.args[1] + return self.shape_of(x)[idx].value def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] split_size = node.args[1] - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - else: - dim = 0 + dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) if isinstance(split_size, (list, tuple)): n_section = [] for s in split_size[:-1]: @@ -1146,17 +1167,18 @@ def _split(self, node: fx.Node) -> relax.Var: n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size return self.block_builder.emit(relax.op.split(x, n_section, dim)) - def _chunk(self, node: fx.Node) -> relax.Var: + def _squeeze(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - chunks = node.args[1] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + return self.block_builder.emit(relax.op.squeeze(x, dim)) - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 2: - dim = node.args[2] - else: - dim = 0 - return self.block_builder.emit(relax.op.split(x, chunks, dim)) + def _tile(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + x = args[0] + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.tile(x, dims)) def _transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) @@ -1164,50 +1186,80 @@ def _transpose(self, node: fx.Node) -> relax.Var: full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - def _squeeze(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 1: - dim = node.args[1] - else: - dim = None - return self.block_builder.emit(relax.op.squeeze(x, dim)) + ########## Creation ########## - def _repeat(self, node: fx.Node) -> relax.Var: + def _arange(self, node: fx.Node) -> relax.Var: import torch # type: ignore - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - - def _tile(self, node: fx.Node) -> relax.Var: - import torch # type: ignore + start_end_step = [None, None, None] + if "start" in node.kwargs: + start_end_step[0] = node.kwargs["start"] + if "end" in node.kwargs: + start_end_step[1] = node.kwargs["end"] + if "step" in node.kwargs: + start_end_step[2] = node.kwargs["step"] - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.tile(args[0], args[1:])) + if len(node.args) == 1: + assert start_end_step[1] is None + start_end_step[1] = node.args[0] + elif len(node.args) == 2: + assert start_end_step[0] is None + assert start_end_step[1] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + elif len(node.args) == 3: + assert start_end_step[0] is None + assert start_end_step[1] is None + assert start_end_step[2] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + start_end_step[2] = node.args[2] - def _cumsum(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] + if start_end_step[0] is None: + start_end_step[0] = 0 + if start_end_step[2] is None: + start_end_step[2] = 1 - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 1: - dim = node.args[1] - else: - dim = None if "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + elif any([isinstance(x, float) for x in start_end_step]): + dtype = self._convert_data_type(torch.get_default_dtype()) else: - dtype = None - if "out" in node.kwargs: - raise ValueError("specifying out for cumsum is not supported yet") + dtype = "int64" + start_end_step = [ + self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step + ] + return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) - return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) + def _empty(self, node: fx.Node) -> relax.Var: + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + return self.block_builder.emit(relax.op.zeros(node.args[0], dtype)) + + def _inplace_fill(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) + filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + self.env[node.args[0]] = filled + return filled + + def _full(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + return self.block_builder.emit( + relax.op.full( + size, + value, + dtype, + ) + ) def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -1215,14 +1267,6 @@ def _index_select(self, node: fx.Node) -> relax.Var: index = self.env[node.args[2]] return self.block_builder.emit(relax.op.take(x, index, dim)) - def _masked_fill(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - mask = self.env[node.args[1]] - value = node.args[2] - rx_value = relax.const(value) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) - return self.block_builder.emit(relax.op.where(mask, values, x)) - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] @@ -1233,168 +1277,79 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = output return output - ########## Neural Network ########## - - def _softmax(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - dim = module.dim - else: - nargs = len(node.args) - dim = node.args[1] if nargs > 1 else node.kwargs["dim"] - assert dim is not None - return self.block_builder.emit(relax.op.nn.softmax(x, dim)) - - def _batch_norm_2d(self, node: fx.Node) -> relax.Var: + def _masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = self.params[module.bias] - running_mean = self._convert_torch_tensor_to_relax(module.running_mean) - running_var = self._convert_torch_tensor_to_relax(module.running_var) - eps = module.eps + mask = self.env[node.args[1]] + rx_value = relax.const(node.args[2]) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + return self.block_builder.emit(relax.op.where(mask, values, x)) - res_tuple = self.block_builder.emit( - relax.op.nn.batch_norm( - x, - weight, - bias, - running_mean, - running_var, - axis=1, - epsilon=eps, + def _new_ones(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1] if isinstance(args[1], (list, tuple)) else args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, ) ) - return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + def _ones(self, node: fx.Node) -> relax.Var: + import torch - def _interpolate(self, node: fx.Node) -> relax.Var: - # torch.nn.functional.interpolate( - # input, size=None, scale_factor=None, mode='nearest', align_corners=None, - # recompute_scale_factor=None, antialias=False) - # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout - # it basically replicates the implementation in tvm.relay.frontend.pytorch - data = self.env[node.args[0]] - size = ( - node.args[1] - if len(node.args) > 1 - else (node.kwargs["size"] if "size" in node.kwargs else None) - ) - scale_factor = ( - node.args[2] - if len(node.args) > 2 - else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs else None) - ) - method = ( - node.args[3] - if len(node.args) > 3 - else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest") - ) - align_corners = ( - node.args[4] - if len(node.args) > 4 - else (node.kwargs["align_corners"] if "align_corners" in node.kwargs else None) - ) - recompute_scale_factor = ( - node.args[5] - if len(node.args) > 5 - else ( - node.kwargs["recompute_scale_factor"] - if "recompute_scale_factor" in node.kwargs - else None - ) - ) - antialias = ( - node.args[6] - if len(node.args) > 6 - else (node.kwargs["antialias"] if "antialias" in node.kwargs else False) + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env ) - - assert recompute_scale_factor is None - assert antialias is False - - if size is None: - shape = self.shape_of(data) - assert isinstance(shape, relax.ShapeExpr) - if isinstance(scale_factor, tuple): - assert len(scale_factor) == len(shape) - 2 - size = tuple( - int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape)) - ) - else: - size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) - - if method.startswith("nearest"): - method = "nearest_neighbor" - elif method[0:2] == "bi": - method = method[2:] - - if method == "nearest_neighbor": - coord_trans = "asymmetric" - elif align_corners: - coord_trans = "align_corners" - else: - coord_trans = "half_pixel" - return self.block_builder.emit( - relax.op.image.resize2d( - data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + relax.op.full( + size, + relax.const(1, dtype), + dtype, ) ) - def _cross_entropy(self, node: fx.Node) -> relax.Expr: - preds = self.env[node.args[0]] - targets = self.env[node.args[1]] - - # functional.cross_entropy - if node.target not in self.named_modules: - weights = node.kwargs["weight"] - if weights is not None: - weights = self.env[weights] - reduction = node.kwargs["reduction"] - ignore_index = node.kwargs["ignore_index"] - - return self.block_builder.emit( - relax.op.nn.nll_loss( - relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index - ) - ) + def _tensor(self, node: fx.Node) -> relax.Var: + dtype = node.kwargs.get("dtype", None) + if isinstance(node.args[0], float): + return relax.const(node.args[0], dtype if dtype is not None else "float32") + elif isinstance(node.args[0], int): + return relax.const(node.args[0], dtype if dtype is not None else "int64") + raise ValueError("torch.tensor with value not a float or int is not accepted") - module = self.named_modules[node.target] + ########## DataType ########## - weights = module.weight - if weights is not None: - if weights in self.params: - weights = self.params[weights] - else: - weights = relax.const(weights.numpy(), preds.struct_info.dtype) - reduction = module.reduction - ignore_index = module.ignore_index + def _float(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) - return self.block_builder.emit( - relax.op.nn.nll_loss( - relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index - ) - ) + def _half(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) - ########## Others ########## + def _to(self, node: fx.Node) -> relax.Var: + import torch - def _sym_size_int(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] - shape = self.shape_of(x) - idx = node.args[1] - return self.block_builder.emit(relax.const(shape[idx].value, "int32")) + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x - def _size(self, node: fx.Node) -> relax.Expr: + def _type(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - shape = self.shape_of(x) - if len(node.args) == 1: - assert isinstance(shape, relax.ShapeExpr) - return shape - assert len(node.args) == 2 - idx = node.args[1] - return self.shape_of(x)[idx].value + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + + ########## Others ########## def _getattr(self, node: fx.Node) -> relax.Var: if isinstance(self.env[node.args[0]], relax.Expr): @@ -1485,6 +1440,12 @@ def _getitem(self, node: fx.Node) -> relax.Var: else: assert False + def _sym_size_int(self, node: fx.Node) -> relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + idx = node.args[1] + return self.block_builder.emit(relax.const(shape[idx].value, "int32")) + def create_convert_map(self): import operator from torch import nn @@ -1511,20 +1472,20 @@ def create_convert_map(self): # neural network nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, nn.AvgPool2d: self._avg_pool2d_module, - nn.BatchNorm2d: self._batch_norm_2d, + nn.BatchNorm2d: self._batch_norm_2d_module, nn.Conv1d: self._conv1d_module, nn.Conv2d: self._conv2d_module, nn.Conv3d: self._conv3d_module, nn.ConvTranspose1d: self._conv1d_transpose_module, nn.ConvTranspose2d: self._conv2d_transpose_module, - nn.CrossEntropyLoss: self._cross_entropy, + nn.CrossEntropyLoss: self._cross_entropy_module, nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm_module, nn.Linear: self._linear_module, nn.MaxPool2d: self._max_pool2d_module, nn.modules.sparse.Embedding: self._embedding_module, # tensor manipulation - nn.Flatten: self._flatten, + nn.Flatten: self._flatten_module, ## call_function and call_method # unary "acos": self._unary_op(relax.op.acos), @@ -1603,6 +1564,7 @@ def create_convert_map(self): "argmin": self._argmax_argmin(relax.op.argmin), # tensor manipulation "cat": self._cat, + "chunk": self._chunk, "concat": self._cat, "contiguous": lambda node: self.env[node.args[0]], "cumsum": self._cumsum, @@ -1622,7 +1584,6 @@ def create_convert_map(self): "view": self._reshape, # tensor creation "arange": self._arange, - "chunk": self._chunk, "empty": self._empty, "fill_": self._inplace_fill, "full": self._full, @@ -1632,11 +1593,11 @@ def create_convert_map(self): "new_ones": self._new_ones, "ones": self._ones, "tensor": self._tensor, - "to": self._to, # datatype "astype": self._type, "float": self._float, "half": self._half, + "to": self._to, "type": self._type, # other "getattr": self._getattr, From a355a5247c8c4b3b2cec65260cffb2668edc7741 Mon Sep 17 00:00:00 2001 From: Arnout Engelen Date: Tue, 17 Sep 2024 03:09:10 +0200 Subject: [PATCH 147/202] [DOCS] Link to project-specific security page (#17378) Make the project-specific information more prominent. This project-specific page already links to the general ASF information at https://apache.org/security/ --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 8c71f5eb1d55..12039ebb2c8f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -627,7 +627,7 @@ def force_gc(gallery_conf, fname): ("Apache Homepage", "https://apache.org/"), ("License", "https://www.apache.org/licenses/"), ("Sponsorship", "https://www.apache.org/foundation/sponsorship.html"), - ("Security", "https://www.apache.org/security/"), + ("Security", "https://tvm.apache.org/docs/reference/security.html"), ("Thanks", "https://www.apache.org/foundation/thanks.html"), ("Events", "https://www.apache.org/events/current-event"), ], From d3900bed871b2fd54b55039fa4b41fe14b4c33e3 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 17 Sep 2024 10:09:20 +0900 Subject: [PATCH 148/202] [CI] Disable NNPACK build and fix error on Android SDK installaion (#17337) * disable nnpack on ci * fix android sdk installation error * port from https://github.com/octoml/relax/pull/38 * remove androidsdk from ci image --- cmake/modules/contrib/TFLite.cmake | 4 ++++ docker/Dockerfile.ci_adreno | 5 ----- docker/Dockerfile.ci_cpu | 8 -------- docker/Dockerfile.ci_gpu | 4 ---- docker/Dockerfile.ci_hexagon | 6 ------ docker/Dockerfile.demo_vitis_ai | 4 ---- docker/install/ubuntu_install_androidsdk.sh | 14 +++++++------- docker/install/ubuntu_install_java.sh | 6 +++--- tests/scripts/task_config_build_cpu.sh | 2 -- tests/scripts/task_config_build_gpu.sh | 2 -- 10 files changed, 14 insertions(+), 41 deletions(-) diff --git a/cmake/modules/contrib/TFLite.cmake b/cmake/modules/contrib/TFLite.cmake index b8d6a0daff19..255dc5fde780 100644 --- a/cmake/modules/contrib/TFLite.cmake +++ b/cmake/modules/contrib/TFLite.cmake @@ -39,6 +39,10 @@ if(NOT USE_TFLITE STREQUAL "OFF") endif() find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${USE_TFLITE}) file(GLOB_RECURSE TFLITE_DEPS "${USE_TFLITE}/*.a") + # the order of the next libs are important for correct build + list(REMOVE_ITEM TFLITE_DEPS "${USE_TFLITE}/_deps/clog-build/libclog.a" "${USE_TFLITE}/_deps/cpuinfo-build/libcpuinfo.a") + list(APPEND TFLITE_DEPS "${USE_TFLITE}/_deps/cpuinfo-build/libcpuinfo.a") + list(APPEND TFLITE_DEPS "${USE_TFLITE}/_deps/clog-build/libclog.a") list(APPEND TVM_RUNTIME_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${TFLITE_DEPS}) diff --git a/docker/Dockerfile.ci_adreno b/docker/Dockerfile.ci_adreno index 961977c54286..30e095b27aac 100644 --- a/docker/Dockerfile.ci_adreno +++ b/docker/Dockerfile.ci_adreno @@ -20,11 +20,6 @@ FROM tlcpack/ci-gpu COPY utils/apt-install-and-clear.sh /usr/local/bin/apt-install-and-clear -# Android SDK -COPY install/ubuntu_install_androidsdk.sh /install/ubuntu_install_androidsdk.sh -RUN bash /install/ubuntu_install_androidsdk.sh 25.2.9519653 3.22.1 33.0.2 33 -ENV PATH /opt/android-sdk-linux/platform-tools:$PATH - # Clang tool for CLML source codegen RUN apt-get update && apt-install-and-clear -y clang-format-15 diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index ae088f5c9e63..17344f7dac22 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -77,10 +77,6 @@ COPY install/ubuntu_install_golang.sh /install/ubuntu_install_golang.sh RUN bash /install/ubuntu_install_golang.sh ENV PATH $PATH:/usr/lib/go-1.18/bin -# NNPACK deps -COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh -RUN bash /install/ubuntu_install_nnpack.sh - # ANTLR deps COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh @@ -129,10 +125,6 @@ RUN bash /install/ubuntu_install_ethosn_driver_stack.sh COPY install/ubuntu_install_vitis_ai_packages_ci.sh /install/ubuntu_install_vitis_ai_packages_ci.sh RUN bash /install/ubuntu_install_vitis_ai_packages_ci.sh -# Android SDK -COPY install/ubuntu_install_androidsdk.sh /install/ubuntu_install_androidsdk.sh -RUN bash /install/ubuntu_install_androidsdk.sh - # PaddlePaddle deps COPY install/ubuntu_install_paddle.sh /install/ubuntu_install_paddle.sh RUN bash /install/ubuntu_install_paddle.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index acb0310a41e2..8d11882098fb 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -133,10 +133,6 @@ RUN bash /install/ubuntu_install_wasmtime.sh COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh -# NNPACK deps -COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh -RUN bash /install/ubuntu_install_nnpack.sh - # BYODT deps COPY install/ubuntu_install_universal.sh /install/ubuntu_install_universal.sh RUN bash /install/ubuntu_install_universal.sh diff --git a/docker/Dockerfile.ci_hexagon b/docker/Dockerfile.ci_hexagon index 3b4c58ef43c9..1855e3a9c231 100644 --- a/docker/Dockerfile.ci_hexagon +++ b/docker/Dockerfile.ci_hexagon @@ -58,12 +58,6 @@ RUN bash /install/ubuntu_install_python_package.sh COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh -# Android SDK -COPY install/ubuntu_install_androidsdk.sh /install/ubuntu_install_androidsdk.sh -RUN bash /install/ubuntu_install_androidsdk.sh -ENV ANDROID_HOME=/opt/android-sdk-linux -ENV PATH /opt/android-sdk-linux/platform-tools:$PATH - # Hexagon COPY install/ubuntu_install_hexagon.sh /install/ubuntu_install_hexagon.sh RUN bash /install/ubuntu_install_hexagon.sh diff --git a/docker/Dockerfile.demo_vitis_ai b/docker/Dockerfile.demo_vitis_ai index b82076dbdf9c..01b0b494bd9e 100644 --- a/docker/Dockerfile.demo_vitis_ai +++ b/docker/Dockerfile.demo_vitis_ai @@ -45,10 +45,6 @@ RUN bash /install/ubuntu_install_python_package.sh COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh RUN bash /install/ubuntu_install_llvm.sh -# NNPACK deps -COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh -RUN bash /install/ubuntu_install_nnpack.sh - ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin # ANTLR deps diff --git a/docker/install/ubuntu_install_androidsdk.sh b/docker/install/ubuntu_install_androidsdk.sh index 5e7278c5d631..193a02745f3a 100755 --- a/docker/install/ubuntu_install_androidsdk.sh +++ b/docker/install/ubuntu_install_androidsdk.sh @@ -25,6 +25,8 @@ ANDROID_HOME=/opt/android-sdk-linux ASDKTOOLS_HOME=/opt/android-sdk-tools ASDKTOOLS_VERSION=3859397 ASDKTOOLS_SHA256=444e22ce8ca0f67353bda4b85175ed3731cae3ffa695ca18119cbacef1c1bea0 +COMMANDLINETOOLS_VERSION=11076708 +COMMANDLINETOOLS_SHA256=2d2d50857e4eb553af5a6dc3ad507a17adf43d115264b1afc116f95c92e5e258 ANDROID_NDK_VERSION=21.3.6528147 CMAKE_VERSION=3.6.4111459 @@ -52,11 +54,11 @@ echo "Cmake Version: ${CMAKE_VERSION}" echo "Build Tools: ${BUILD_TOOLS_VERSION}" echo "Android Platform: ${ANDROID_PLATFORM}" -wget -q http://dl.google.com/android/repository/sdk-tools-linux-${ASDKTOOLS_VERSION}.zip -O sdk-tools-linux.zip -echo "${ASDKTOOLS_SHA256} *sdk-tools-linux.zip" | sha256sum --check - -unzip sdk-tools-linux.zip -rm sdk-tools-linux.zip -mv tools "${ASDKTOOLS_HOME}/" +wget -q https://dl.google.com/android/repository/commandlinetools-linux-${COMMANDLINETOOLS_VERSION}_latest.zip -O commandlinetools-linux.zip +echo "${COMMANDLINETOOLS_SHA256} commandlinetools-linux.zip" | sha256sum --check - +unzip commandlinetools-linux.zip +rm commandlinetools-linux.zip +mv cmdline-tools/ "${ASDKTOOLS_HOME}/" # The following popular fix makes sdkmanager honour $http_proxy variables mv ${ASDKTOOLS_HOME}/bin/sdkmanager ${ASDKTOOLS_HOME}/bin/sdkmanager-vanilla cat >${ASDKTOOLS_HOME}/bin/sdkmanager <<"EOF" @@ -90,8 +92,6 @@ extras;google;market_apk_expansion extras;google;market_licensing extras;google;simulators extras;google;webdriver -extras;m2repository;com;android;support;constraint;constraint-layout;1.0.2 -extras;m2repository;com;android;support;constraint;constraint-layout-solver;1.0.2 platforms;android-26 platforms;android-${ANDROID_PLATFORM} tools diff --git a/docker/install/ubuntu_install_java.sh b/docker/install/ubuntu_install_java.sh index 5556f0d8fed5..c4a8c5f9acb5 100755 --- a/docker/install/ubuntu_install_java.sh +++ b/docker/install/ubuntu_install_java.sh @@ -20,7 +20,7 @@ set -o errexit -o nounset set -o pipefail apt-get update -apt-install-and-clear -y openjdk-8-jdk maven +apt-install-and-clear -y openjdk-17-jdk maven arch=$(uname -m) jre_arch="unknown" case $arch in @@ -36,8 +36,8 @@ case $arch in ;; esac -if [ ! -d "/usr/lib/jvm/java-8-openjdk-$jre_arch/jre" ]; then +if [ ! -d "/usr/lib/jvm/java-17-openjdk-$jre_arch" ]; then echo "error: missing openjdk for $jre_arch" >&2 exit 1 fi -echo "export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-$jre_arch/jre" >> /etc/profile +echo "export JAVA_HOME=/usr/lib/jvm/java-17-openjdk-$jre_arch" >> /etc/profile diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index f509aad30627..c97321e538bd 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -30,8 +30,6 @@ echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_DNNL ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake echo set\(USE_LLVM \"/usr/bin/llvm-config-17 --link-static\"\) >> config.cmake -echo set\(USE_NNPACK ON\) >> config.cmake -echo set\(NNPACK_PATH /NNPACK/build/\) >> config.cmake echo set\(USE_ANTLR ON\) >> config.cmake echo set\(CMAKE_CXX_FLAGS \"-Werror -Wno-error=range-loop-construct\"\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index e68e646ce178..03f90c5ad4a1 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -33,8 +33,6 @@ echo set\(USE_OPENCL_GTEST \"/googletest\"\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake echo set\(USE_LLVM \"/usr/bin/llvm-config-15 --link-static\"\) >> config.cmake -echo set\(USE_NNPACK ON\) >> config.cmake -echo set\(NNPACK_PATH /NNPACK/build/\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_SORT ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake From 4692b9591d3d9992473f733d96c1b14eb00cd7a3 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 16 Sep 2024 21:12:20 -0400 Subject: [PATCH 149/202] [DOCS] Update document to include security model of RPC server (#17377) This PR update the documents to include the security model of the RPC server. --- docs/reference/security.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/reference/security.rst b/docs/reference/security.rst index c2603dd33ee5..6093063bd98e 100644 --- a/docs/reference/security.rst +++ b/docs/reference/security.rst @@ -34,10 +34,16 @@ The private security mailing address is: `security@apache.org `_. -Considerations +Security Model -------------- The default binary generated by TVM only relies on a minimum runtime API. The runtime depends on a limited set of system calls(e.g. malloc) in the system library. + +TVM RPC server assumes that the user is trusted and needs to be used in a trusted network environment +and encrypted channels. It allows writings of arbitrary files into the server and provide +full remote code execution capabilities to anyone who can access this API. + + AutoTVM data exchange between the tracker, server and client are in plain-text. It is recommended to use them under trusted networking environment or encrypted channels. From 1435ddb118ce4fc6b87c07804e554c2e945053c9 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 17 Sep 2024 22:06:38 +0800 Subject: [PATCH 150/202] [Doc] Relax Deep Dive (#17380) * [Doc] Relax Deep Dive Similar as TensorIR Deep Dive, we also have Relax Deep Dive. --- docs/conf.py | 7 +- docs/deep_dive/relax/abstraction.rst | 73 +++++ docs/deep_dive/relax/index.rst | 34 +++ docs/deep_dive/relax/learning.rst | 272 +++++++++++++++++ docs/deep_dive/relax/tutorials/README.txt | 2 + .../relax/tutorials/relax_creation.py | 281 ++++++++++++++++++ .../relax/tutorials/relax_transformation.py | 141 +++++++++ docs/deep_dive/tensor_ir/abstraction.rst | 1 - docs/deep_dive/tensor_ir/index.rst | 6 +- .../{creation.py => tir_creation.py} | 0 ...ransformation.py => tir_transformation.py} | 0 docs/index.rst | 1 + 12 files changed, 811 insertions(+), 7 deletions(-) create mode 100644 docs/deep_dive/relax/abstraction.rst create mode 100644 docs/deep_dive/relax/index.rst create mode 100644 docs/deep_dive/relax/learning.rst create mode 100644 docs/deep_dive/relax/tutorials/README.txt create mode 100644 docs/deep_dive/relax/tutorials/relax_creation.py create mode 100644 docs/deep_dive/relax/tutorials/relax_transformation.py rename docs/deep_dive/tensor_ir/tutorials/{creation.py => tir_creation.py} (100%) rename docs/deep_dive/tensor_ir/tutorials/{transformation.py => tir_transformation.py} (100%) diff --git a/docs/conf.py b/docs/conf.py index 12039ebb2c8f..acc03161e559 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -424,6 +424,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): # New tutorial structure under docs folder tvm_path.joinpath("docs", "get_started", "tutorials"), tvm_path.joinpath("docs", "how_to", "tutorials"), + tvm_path.joinpath("docs", "deep_dive", "relax", "tutorials"), tvm_path.joinpath("docs", "deep_dive", "tensor_ir", "tutorials"), ] @@ -443,6 +444,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): # New tutorial structure under docs folder "get_started/tutorials/", "how_to/tutorials/", + "deep_dive/relax/tutorials/", "deep_dive/tensor_ir/tutorials/", ] @@ -598,10 +600,10 @@ def force_gc(gallery_conf, fname): ## Setup header and other configs import tlcpack_sphinx_addon -footer_copyright = "© 2023 Apache Software Foundation | All rights reserved" +footer_copyright = "© 2024 Apache Software Foundation | All rights reserved" footer_note = " ".join( """ -Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, +Copyright © 2024 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.""".split( "\n" @@ -614,7 +616,6 @@ def force_gc(gallery_conf, fname): header_links = [ ("Community", "https://tvm.apache.org/community"), ("Download", "https://tvm.apache.org/download"), - ("VTA", "https://tvm.apache.org/vta"), ("Blog", "https://tvm.apache.org/blog"), ("Docs", "https://tvm.apache.org/docs"), ("Conference", "https://tvmconf.org"), diff --git a/docs/deep_dive/relax/abstraction.rst b/docs/deep_dive/relax/abstraction.rst new file mode 100644 index 000000000000..2b9ee8b5d741 --- /dev/null +++ b/docs/deep_dive/relax/abstraction.rst @@ -0,0 +1,73 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _relax-abstraction: + +Graph Abstraction for ML Models +------------------------------- +Graph abstraction is a key technique used in machine learning (ML) compilers +to represent and reason about the structure and data flow of ML models. By +abstracting the model into a graph representation, the compiler can perform +various optimizations to improve performance and efficiency. This tutorial will +cover the basics of graph abstraction, its key elements of Relax IR, and how it enables optimization in ML compilers. + +What is Graph Abstraction? +~~~~~~~~~~~~~~~~~~~~~~~~~~ +Graph abstraction is the process of representing an ML model as a directed graph, +where the nodes represent computational operations (e.g., matrix multiplication, +convolution) and the edges represent the flow of data between these operations. +This abstraction allows the compiler to analyze the dependencies and +relationships between different parts of the model. + +.. code:: python + + from tvm.script import relax as R + + @R.function + def main( + x: R.Tensor((1, 784), dtype="float32"), + weight: R.Tensor((784, 256), dtype="float32"), + bias: R.Tensor((256,), dtype="float32"), + ) -> R.Tensor((1, 256), dtype="float32"): + with R.dataflow(): + lv0 = R.matmul(x, weight) + lv1 = R.add(lv0, bias) + gv = R.nn.relu(lv1) + R.output(gv) + return gv + +Key Features of Relax +~~~~~~~~~~~~~~~~~~~~~ +Relax, the graph representation utilized in Apache TVM's Unity strategy, +facilitates end-to-end optimization of ML models through several crucial +features: + +- **First-class symbolic shape**: Relax employs symbolic shapes to represent + tensor dimensions, enabling global tracking of dynamic shape relationships + across tensor operators and function calls. + +- **Multi-level abstractions**: Relax supports cross-level abstractions, from + high-level neural network layers to low-level tensor operations, enabling + optimizations that span different hierarchies within the model. + +- **Composable transformations**: Relax offers a framework for composable + transformations that can be selectively applied to different model components. + This includes capabilities such as partial lowering and partial specialization, + providing flexible customization and optimization options. + +These features collectively empower Relax to offer a powerful and adaptable approach +to ML model optimization within the Apache TVM ecosystem. diff --git a/docs/deep_dive/relax/index.rst b/docs/deep_dive/relax/index.rst new file mode 100644 index 000000000000..f891eb2793ec --- /dev/null +++ b/docs/deep_dive/relax/index.rst @@ -0,0 +1,34 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _relax: + +Relax +===== +Relax is a high-level abstraction for graph optimization and transformation in Apache TVM stack. +Additionally, Apache TVM combine Relax and TensorIR together as a unity strategy for cross-level +optimization. Hence, Relax is usually working closely with TensorIR for representing and optimizing +the whole IRModule + + +.. toctree:: + :maxdepth: 2 + + abstraction + learning + tutorials/relax_creation + tutorials/relax_transformation diff --git a/docs/deep_dive/relax/learning.rst b/docs/deep_dive/relax/learning.rst new file mode 100644 index 000000000000..702b0e0a9f29 --- /dev/null +++ b/docs/deep_dive/relax/learning.rst @@ -0,0 +1,272 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _relax-learning: + +Understand Relax Abstraction +============================ +Relax is a graph abstraction used in Apache TVM Unity strategy, which +helps to end-to-end optimize ML models. The principal objective of Relax +is to depict the structure and data flow of ML models, including the +dependencies and relationships between different parts of the model, as +well as how to execute the model on hardware. + +End to End Model Execution +-------------------------- + +In this chapter, we will use the following model as an example. This is +a two-layer neural network that consists of two linear operations with +relu activation. + +.. image:: https://mlc.ai/_images/e2e_fashionmnist_mlp_model.png + :width: 85% + :align: center + + +High-Level Operations Representation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Let us begin by reviewing a Numpy implementation of the model. + +.. code:: python + + def numpy_mlp(data, w0, b0, w1, b1): + lv0 = data @ w0 + b0 + lv1 = np.maximum(lv0, 0) + lv2 = lv1 @ w1 + b1 + return lv2 + +The above example code shows the high-level array operations to perform the end-to-end model +execution. Of course, we can rewrite the above code using Relax as follows: + +.. code:: python + + from tvm.script import relax as R + + @R.function + def relax_mlp( + data: R.Tensor(("n", 784), dtype="float32"), + w0: R.Tensor((784, 128), dtype="float32"), + b0: R.Tensor((128,), dtype="float32"), + w1: R.Tensor((128, 10), dtype="float32"), + b1: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor(("n", 10), dtype="float32"): + with R.dataflow(): + lv0 = R.matmul(data, w0) + b0 + lv1 = R.nn.relu(lv0) + lv2 = R.matmul(lv1, w1) + b1 + R.output(lv2) + return lv2 + +Low-Level Integration +~~~~~~~~~~~~~~~~~~~~~ + +However, again from the pov of machine learning compilation (MLC), we would like to see +through the details under the hood of these array computations. + +For the purpose of illustrating details under the hood, we will again write examples in low-level numpy: + +We will use a loop instead of array functions when necessary to demonstrate the possible loop computations. +When possible, we always explicitly allocate arrays via numpy.empty and pass them around. +The code block below shows a low-level numpy implementation of the same model. + +.. code:: python + + def lnumpy_linear(X: np.ndarray, W: np.ndarray, B: np.ndarray, Z: np.ndarray): + n, m, K = X.shape[0], W.shape[1], X.shape[1] + Y = np.empty((n, m), dtype="float32") + for i in range(n): + for j in range(m): + for k in range(K): + if k == 0: + Y[i, j] = 0 + Y[i, j] = Y[i, j] + X[i, k] * W[k, j] + + for i in range(n): + for j in range(m): + Z[i, j] = Y[i, j] + B[j] + + + def lnumpy_relu0(X: np.ndarray, Y: np.ndarray): + n, m = X.shape + for i in range(n): + for j in range(m): + Y[i, j] = np.maximum(X[i, j], 0) + + def lnumpy_mlp(data, w0, b0, w1, b1): + n = data.shape[0] + lv0 = np.empty((n, 128), dtype="float32") + lnumpy_matmul(data, w0, b0, lv0) + + lv1 = np.empty((n, 128), dtype="float32") + lnumpy_relu(lv0, lv1) + + out = np.empty((n, 10), dtype="float32") + lnumpy_matmul(lv1, w1, b1, out) + return out + +With the low-level NumPy example in mind, now we are ready to introduce an Relax abstraction +for the end-to-end model execution. The code block below shows a TVMScript implementation of the model. + +.. code:: python + + @I.ir_module + class Module: + @T.prim_func(private=True) + def linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle): + M, N, K = T.int64(), T.int64(), T.int64() + X = T.match_buffer(x, (M, K), "float32") + W = T.match_buffer(w, (K, N), "float32") + B = T.match_buffer(b, (N,), "float32") + Z = T.match_buffer(z, (M, N), "float32") + Y = T.alloc_buffer((M, N), "float32") + for i, j, k in T.grid(M, N, K): + with T.block("Y"): + v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Y[v_i, v_j] = T.float32(0.0) + Y[v_i, v_j] = Y[v_i, v_j] + X[v_i, v_k] * W[v_k, v_j] + for i, j in T.grid(M, N): + with T.block("Z"): + v_i, v_j = T.axis.remap("SS", [i, j]) + Z[v_i, v_j] = Y[v_i, v_j] + B[v_j] + + @T.prim_func(private=True) + def relu(x: T.handle, y: T.handle): + M, N = T.int64(), T.int64() + X = T.match_buffer(x, (M, N), "float32") + Y = T.match_buffer(y, (M, N), "float32") + for i, j in T.grid(M, N): + with T.block("Y"): + v_i, v_j = T.axis.remap("SS", [i, j]) + Y[v_i, v_j] = T.max(X[v_i, v_j], T.float32(0.0)) + + @R.function + def main( + x: R.Tensor(("n", 784), dtype="float32"), + w0: R.Tensor((784, 256), dtype="float32"), + b0: R.Tensor((256,), dtype="float32"), + w1: R.Tensor((256, 10), dtype="float32"), + b1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor(("n", 10), dtype="float32"): + cls = Module + n = T.int64() + with R.dataflow(): + lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32")) + lv1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 256), dtype="float32")) + lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((b, 10), dtype="float32")) + R.output(lv2) + return lv2 + +The above code contains kinds of functions: the primitive tensor functions (``T.prim_func``) and a +``R.function`` (relax function). Relax function is a new type of abstraction representing +high-level neural network executions. + +Note that the above relax module natively supports symbolic shapes, see the ``"n"`` in the +tensor shapes in ``main`` function and ``M``, ``N``, ``K`` in the ``linear`` function. This is +a key feature of Relax abstraction, which enables the compiler to track dynamic shape relations +globally across tensor operators and function calls. + +Again it is helpful to see the TVMScript code and low-level numpy code side-by-side and check the +corresponding elements, and we are going to walk through each of them in detail. Since we already +learned about primitive tensor functions, we are going to focus on the high-level execution part. + +Key Elements of Relax +--------------------- +This section will introduce the key elements of Relax abstraction and how it enables optimization +in ML compilers. + +Structure Info +~~~~~~~~~~~~~~ +Structure info is a new concept in Relax that represents the type of relax expressions. It can +be ``TensorStructInfo``, ``TupleStructInfo``, etc. In the above example, we use ``TensorStructInfo`` +(short in ``R.Tensor`` in TVMScript) to represent the shape and dtype of the tensor of the inputs, +outputs, and intermediate results. + +R.call_tir +~~~~~~~~~~ +The ``R.call_tir`` function is a new abstraction in Relax that allows calling primitive tensor +functions in the same IRModule. This is a key feature of Relax that enables cross-level +abstractions, from high-level neural network layers to low-level tensor operations. +Taking one line from the above code as an example: + +.. code:: python + + lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32")) + +To explain what does ``R.call_tir`` work, let us review an equivalent low-level numpy +implementation of the operation, as follows: + +.. code:: python + + lv0 = np.empty((n, 256), dtype="float32") + lnumpy_linear(x, w0, b0, lv0) + +Specifically, ``call_tir`` allocates an output tensor res, then pass the inputs and the output +to the prim_func. After executing prim_func the result is populated in res, then we can return +the result. + +This convention is called **destination passing**, The idea is that input and output are explicitly +allocated outside and passed to the low-level primitive function. This style is commonly used +in low-level library designs, so higher-level frameworks can handle that memory allocation +decision. Note that not all tensor operations can be presented in this style (specifically, +there are operations whose output shape depends on the input). Nevertheless, in common practice, +it is usually helpful to write the low-level function in this style when possible. + +Dataflow Block +~~~~~~~~~~~~~~ +Another important element in a relax function is the R.dataflow() scope annotation. + +.. code:: python + + with R.dataflow(): + lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32")) + lv1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 256), dtype="float32")) + lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((b, 10), dtype="float32")) + R.output(lv2) + +Before we talk about the dataflow block, let us first introduce the concept of **pure** and +**side-effect**. A function is **pure** or **side-effect free** if: + +- it only reads from its inputs and returns the result via its output +- it will not change other parts of the program (such as incrementing a global counter). + +For example, all ``R.call_tir`` functions are pure functions, as they only read from their inputs +and write the output to another new allocated tensor. However, the **inplace operations** are not +pure functions, in other words, they are side-effect functions, because they will change the existing +intermediate or input tensors. + +A dataflow block is a way for us to mark the computational graph regions of the program. +Specifically, within a dataflow block, all the operations need to be **side-effect free**. +Outside a dataflow block, the operations can contain side-effect. + +.. note:: + + A common question that arises is why we need to manually mark dataflow blocks instead of + automatically inferring them. There are two main reasons for this approach: + + - Automatic inference of dataflow blocks can be challenging and imprecise, particularly + when dealing with calls to packed functions (such as cuBLAS integrations). By manually + marking dataflow blocks, we enable the compiler to accurately understand and optimize + the program's dataflow. + - Many optimizations can only be applied within dataflow blocks. For instance, fusion + optimization is limited to operations within a single dataflow block. If the compiler + were to incorrectly infer dataflow boundaries, it might miss crucial optimization + opportunities, potentially impacting the program's performance. + +By allowing manual marking of dataflow blocks, we ensure that the compiler has the most +accurate information to work with, leading to more effective optimizations. diff --git a/docs/deep_dive/relax/tutorials/README.txt b/docs/deep_dive/relax/tutorials/README.txt new file mode 100644 index 000000000000..b532ae9386ec --- /dev/null +++ b/docs/deep_dive/relax/tutorials/README.txt @@ -0,0 +1,2 @@ +Deep Dive: Relax +---------------- diff --git a/docs/deep_dive/relax/tutorials/relax_creation.py b/docs/deep_dive/relax/tutorials/relax_creation.py new file mode 100644 index 000000000000..f6278e3b65b1 --- /dev/null +++ b/docs/deep_dive/relax/tutorials/relax_creation.py @@ -0,0 +1,281 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _relax-creation: + +Relax Creation +============== +This tutorial demonstrates how to create Relax functions and programs. +We'll cover various ways to define Relax functions, including using TVMScript, +and relax NNModule API. +""" + + +###################################################################### +# Create Relax programs using TVMScript +# ------------------------------------- +# TVMScript is a domain-specific language for representing Apache TVM's +# intermediate representation (IR). It is a Python dialect that can be used +# to define an IRModule, which contains both TensorIR and Relax functions. +# +# In this section, we will show how to define a simple MLP model with only +# high-level Relax operators using TVMScript. + +from tvm import relax, topi +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +@I.ir_module +class RelaxModule: + @R.function + def forward( + data: R.Tensor(("n", 784), dtype="float32"), + w0: R.Tensor((128, 784), dtype="float32"), + b0: R.Tensor((128,), dtype="float32"), + w1: R.Tensor((10, 128), dtype="float32"), + b1: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor(("n", 10), dtype="float32"): + with R.dataflow(): + lv0 = R.matmul(data, R.permute_dims(w0)) + b0 + lv1 = R.nn.relu(lv0) + lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1 + R.output(lv2) + return lv2 + + +RelaxModule.show() + +###################################################################### +# Relax is not only a graph-level IR, but also supports cross-level +# representation and transformation. To be specific, we can directly call +# TensorIR functions in Relax function. + + +@I.ir_module +class RelaxModuleWithTIR: + @T.prim_func + def relu(x: T.handle, y: T.handle): + n, m = T.int64(), T.int64() + X = T.match_buffer(x, (n, m), "float32") + Y = T.match_buffer(y, (n, m), "float32") + for i, j in T.grid(n, m): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + Y[vi, vj] = T.max(X[vi, vj], T.float32(0)) + + @R.function + def forward( + data: R.Tensor(("n", 784), dtype="float32"), + w0: R.Tensor((128, 784), dtype="float32"), + b0: R.Tensor((128,), dtype="float32"), + w1: R.Tensor((10, 128), dtype="float32"), + b1: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor(("n", 10), dtype="float32"): + n = T.int64() + cls = RelaxModuleWithTIR + with R.dataflow(): + lv0 = R.matmul(data, R.permute_dims(w0)) + b0 + lv1 = R.call_tir(cls.relu, lv0, R.Tensor((n, 128), dtype="float32")) + lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1 + R.output(lv2) + return lv2 + + +RelaxModuleWithTIR.show() + +###################################################################### +# .. note:: +# +# You may notice that the printed output is different from the written +# TVMScript code. This is because we print the IRModule in a standard +# format, while we support syntax sugar for the input +# +# For example, we can combine multiple operators into a single line, as +# +# .. code-block:: python +# +# lv0 = R.matmul(data, R.permute_dims(w0)) + b0 +# +# However, the normalized expression requires only one operation in one +# binding. So the printed output is different from the written TVMScript code, +# as +# +# .. code-block:: python +# +# lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None) +# lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void") +# lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0) +# + +###################################################################### +# Create Relax programs using NNModule API +# ---------------------------------------- +# Besides TVMScript, we also provide a PyTorch-like API for defining neural networks. +# It is designed to be more intuitive and easier to use than TVMScript. +# +# In this section, we will show how to define the same MLP model using +# Relax NNModule API. + +from tvm.relax.frontend import nn + + +class NNModule(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(784, 128) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + return x + + +###################################################################### +# After we define the NNModule, we can export it to TVM IRModule via +# ``export_tvm``. + +mod, params = NNModule().export_tvm({"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}) +mod.show() + +###################################################################### +# We can also insert customized function calls into the NNModule, such as +# Tensor Expression(TE), TensorIR functions or other TVM packed functions. + + +@T.prim_func +def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle): + M, N, K = T.int64(), T.int64(), T.int64() + X = T.match_buffer(x, (M, K), "float32") + W = T.match_buffer(w, (N, K), "float32") + B = T.match_buffer(b, (N,), "float32") + Z = T.match_buffer(z, (M, N), "float32") + for i, j, k in T.grid(M, N, K): + with T.block("linear"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Z[vi, vj] = 0 + Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk] + for i, j in T.grid(M, N): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + Z[vi, vj] = Z[vi, vj] + B[vj] + + +class NNModuleWithTIR(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(784, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + n = x.shape[0] + # We can call external functions using nn.extern + x = nn.extern( + "env.linear", + [x, self.fc1.weight, self.fc1.bias], + out=nn.Tensor.placeholder((n, 128), "float32"), + ) + # We can also call TensorIR via Tensor Expression API in TOPI + x = nn.tensor_expr_op(topi.nn.relu, "relu", [x]) + # We can also call other TVM packed functions + x = nn.tensor_ir_op( + tir_linear, + "tir_linear", + [x, self.fc2.weight, self.fc2.bias], + out=nn.Tensor.placeholder((n, 10), "float32"), + ) + return x + + +mod, params = NNModuleWithTIR().export_tvm( + {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}} +) +mod.show() + + +###################################################################### +# Create Relax programs using Block Builder API +# --------------------------------------------- +# In addition to the above APIs, we also provide a Block Builder API for +# creating Relax programs. It is a IR builder API, which is more +# low-level and widely used in TVM's internal logic, e.g writing a +# customized pass. + +bb = relax.BlockBuilder() +n = T.int64() +x = relax.Var("x", R.Tensor((n, 784), "float32")) +fc1_weight = relax.Var("fc1_weight", R.Tensor((128, 784), "float32")) +fc1_bias = relax.Var("fc1_bias", R.Tensor((128,), "float32")) +fc2_weight = relax.Var("fc2_weight", R.Tensor((10, 128), "float32")) +fc2_bias = relax.Var("fc2_bias", R.Tensor((10,), "float32")) +with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]): + with bb.dataflow(): + lv0 = bb.emit(relax.op.matmul(x, relax.op.permute_dims(fc1_weight)) + fc1_bias) + lv1 = bb.emit(relax.op.nn.relu(lv0)) + gv = bb.emit(relax.op.matmul(lv1, relax.op.permute_dims(fc2_weight)) + fc2_bias) + bb.emit_output(gv) + bb.emit_func_output(gv) + +mod = bb.get() +mod.show() + +###################################################################### +# Also, Block Builder API supports building cross-level IRModule with both +# Relax functions, TensorIR functions and other TVM packed functions. + +bb = relax.BlockBuilder() +with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]): + with bb.dataflow(): + lv0 = bb.emit( + relax.call_dps_packed( + "env.linear", + [x, fc1_weight, fc1_bias], + out_sinfo=relax.TensorStructInfo((n, 128), "float32"), + ) + ) + lv1 = bb.emit_te(topi.nn.relu, lv0) + tir_gv = bb.add_func(tir_linear, "tir_linear") + gv = bb.emit( + relax.call_tir( + tir_gv, + [lv1, fc2_weight, fc2_bias], + out_sinfo=relax.TensorStructInfo((n, 10), "float32"), + ) + ) + bb.emit_output(gv) + bb.emit_func_output(gv) +mod = bb.get() +mod.show() + +###################################################################### +# Note that the Block Builder API is not as user-friendly as the above APIs, +# but it is lowest-level API and works closely with the IR definition. We +# recommend using the above APIs for users who only want to define and +# transform a ML model. But for those who want to build more complex +# transformations, the Block Builder API is a more flexible choice. + +###################################################################### +# Summary +# ------- +# This tutorial demonstrates how to create Relax programs using TVMScript, +# NNModule API, Block Builder API and PackedFunc API for different use cases. diff --git a/docs/deep_dive/relax/tutorials/relax_transformation.py b/docs/deep_dive/relax/tutorials/relax_transformation.py new file mode 100644 index 000000000000..01d8e4e32039 --- /dev/null +++ b/docs/deep_dive/relax/tutorials/relax_transformation.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _relax-transform: + +Transformation +-------------- +In this section, we will dive into the transformation of Relax programs. +Transformations is one of the key ingredients of the compilation flows +for optimizing and integrating with hardware backends. +""" + +###################################################################### +# Let's first create a simple Relax program as what we have done in +# the :ref:`previous section `. + +import tvm +from tvm import IRModule, relax +from tvm.relax.frontend import nn + + +class NNModule(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(784, 128) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + return x + + +origin_mod, params = NNModule().export_tvm( + {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}} +) +origin_mod.show() + +###################################################################### +# Apply transformations +# ~~~~~~~~~~~~~~~~~~~~~ +# Passes are the main way to apply transformations to the program. +# We can apply passes to the program. As first step, let's apply +# a built-in pass ``LegalizeOps`` to lower the high-level operators +# into low-level operators. + +mod = tvm.relax.transform.LegalizeOps()(origin_mod) +mod.show() + +###################################################################### +# As we can see from the output, the high-level operators (aka ``relax.op``) in the program +# are replaced by their corresponding low-level operators (aka ``relax.call_tir``). +# +# Then let's trying to apply the operator fusion, which is a wide-used optimization technique +# in ML compilers. Note that in relax, fusion optimizations are done with the collaboration of +# a set of passes. We can apply them in a sequence. + +mod = tvm.ir.transform.Sequential( + [ + tvm.relax.transform.AnnotateTIROpPattern(), + tvm.relax.transform.FuseOps(), + tvm.relax.transform.FuseTIR(), + ] +)(mod) +mod.show() + +###################################################################### +# As result, we can see that the ``matmul``, ``add`` and ``relu`` operators are fused +# into one kernel (aka one ``call_tir``). +# +# For all built-in passes, please refer to :py:class:`relax.transform`. +# +# Custom Passes +# ~~~~~~~~~~~~~ +# We can also define our own passes. Let's taking an example of rewrite the ``relu`` +# operator to ``gelu`` operator. +# +# First, we need to write a Relax IR Mutator to do the rewriting. + +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@mutator +class ReluRewriter(PyExprMutator): + def __init__(self, mod): + super().__init__(mod) + + def visit_call_(self, call: relax.Call) -> relax.Expr: + # visit the relax.Call expr, and only handle the case when op is relax.nn.relu + if call.op.name == "relax.nn.relu": + return relax.op.nn.gelu(call.args[0]) + + return super().visit_call_(call) + + +###################################################################### +# Then we can write a pass to apply the mutator to the whole module. + + +@tvm.transform.module_pass(opt_level=0, name="ReluToGelu") +class ReluToGelu: # pylint: disable=too-few-public-methods + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + rewriter = ReluRewriter(mod) + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + func = rewriter.visit_expr(func) + rewriter.builder_.update_func(g_var, func) + return rewriter.builder_.get() + + +mod = ReluToGelu()(origin_mod) +mod.show() + +###################################################################### +# The printed output shows that the ``relax.nn.relu`` operator is +# rewritten to ``relax.nn.gelu`` operator. +# +# For the details of the mutator, please refer to :py:class:`relax.expr_functor.PyExprMutator`. +# +# Summary +# ~~~~~~~ +# In this section, we have shown how to apply transformations to the Relax program. +# We have also shown how to define and apply custom transformations. diff --git a/docs/deep_dive/tensor_ir/abstraction.rst b/docs/deep_dive/tensor_ir/abstraction.rst index fc11d7f39156..a832fef995f1 100644 --- a/docs/deep_dive/tensor_ir/abstraction.rst +++ b/docs/deep_dive/tensor_ir/abstraction.rst @@ -44,7 +44,6 @@ the compute statements themselves. Key Elements of Tensor Programs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - The demonstrated primitive tensor function calculates the element-wise sum of two vectors. The function: diff --git a/docs/deep_dive/tensor_ir/index.rst b/docs/deep_dive/tensor_ir/index.rst index 432d47116a3c..46bed7c42319 100644 --- a/docs/deep_dive/tensor_ir/index.rst +++ b/docs/deep_dive/tensor_ir/index.rst @@ -19,7 +19,7 @@ TensorIR ======== -TensorIR is one of the core abstraction in Apache TVM Unity stack, which is used to +TensorIR is one of the core abstraction in Apache TVM stack, which is used to represent and optimize the primitive tensor functions. .. toctree:: @@ -27,5 +27,5 @@ represent and optimize the primitive tensor functions. abstraction learning - tutorials/creation - tutorials/transformation + tutorials/tir_creation + tutorials/tir_transformation diff --git a/docs/deep_dive/tensor_ir/tutorials/creation.py b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py similarity index 100% rename from docs/deep_dive/tensor_ir/tutorials/creation.py rename to docs/deep_dive/tensor_ir/tutorials/tir_creation.py diff --git a/docs/deep_dive/tensor_ir/tutorials/transformation.py b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py similarity index 100% rename from docs/deep_dive/tensor_ir/tutorials/transformation.py rename to docs/deep_dive/tensor_ir/tutorials/tir_transformation.py diff --git a/docs/index.rst b/docs/index.rst index 2eec0cb99e97..2102bdd33a00 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -55,6 +55,7 @@ driving its costs down. :caption: Deep Dive deep_dive/tensor_ir/index + deep_dive/relax/index .. toctree:: :maxdepth: 1 From 9f281758e8a1a3c1c649b995367b0166da55f2c6 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 17 Sep 2024 23:07:22 +0900 Subject: [PATCH 151/202] [CI] Upgrade PyTorch to 2.4.1 (#17338) upgrade pytorch to 2.4.1 --- docker/install/ubuntu_install_onnx.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index 2bb50c619815..6cea0075c102 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -36,6 +36,6 @@ pip3 install \ pip3 install future pip3 install \ - torch==2.0.0 \ - torchvision==0.15.1 \ + torch==2.4.1 \ + torchvision==0.19.1 \ --extra-index-url https://download.pytorch.org/whl/cpu From ff8e41644fde86714d6dbf021d57baebe3a1ec1a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 17 Sep 2024 09:07:41 -0500 Subject: [PATCH 152/202] [TVMScript] Avoid segfault from invalid TVMScript (#17373) * [TVMScript] Avoid segfault from invalid TVMScript Prior to this commit, after the `DiagnosticContext` prints its error, it overwrites the `DiagnosticRenderer` with a NULL renderer. If a second call to `DiagnosticContext::Render` occurs, it will segfault. This appears to be intended to prevent double-printing of error messages, but double-printing error messages is much worse than a segfault. In addition, `DiagnosticContext::Render` should only be called once. There's a common pattern in the parser where it will wrap exceptions in `DiagnosticError`, but re-raise exceptions that are already a `DiagnosticError`. This requires every such location to include `except DiagnosticError: raise`, and can easily be missed. This PR makes two changes: First, the `DiagnosticRenderer` is updated to have a no-op callback rather than a NULL callback. Second, the re-raising of `DiagnosticError` is moved to `Parser.report_error`, so that it does not need to be handled separately at several independent locations in the TVMScript parser. --- python/tvm/script/parser/core/evaluator.py | 12 ++++++------ python/tvm/script/parser/core/parser.py | 19 ++++++++++--------- python/tvm/script/parser/relax/parser.py | 10 +++++----- src/ir/diagnostic.cc | 3 ++- tests/python/relax/test_tvmscript_parser.py | 14 +++++++++++--- .../test_tvmscript_printer_highlight.py | 8 +++++--- 6 files changed, 39 insertions(+), 27 deletions(-) diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 26e9d091bfb8..7a194c779d96 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -267,8 +267,8 @@ def _visit(self, node: doc.AST) -> Any: value = self._eval_slice(fields) else: value = self._eval_expr(node.__class__(**fields)) - except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, e) + except Exception as err: # pylint: disable=broad-except + self.parser.report_error(node, err) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: @@ -286,8 +286,8 @@ def _eval_lambda(self, node: doc.Lambda) -> Any: """ try: value = self._eval_expr(node) - except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + except Exception as err: # pylint: disable=broad-except + self.parser.report_error(node, err) return self._add_intermediate_result(value) def _eval_bool_op(self, fields: Dict[str, Any]) -> Any: @@ -463,8 +463,8 @@ def eval_assign( """ try: return _eval_assign(target, source) - except Exception as e: # pylint: disable=broad-except,invalid-name - parser.report_error(target, f"Failed to evaluate assignment: {str(e)}") + except Exception as err: # pylint: disable=broad-except + parser.report_error(target, err) raise diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 0ecf669566a2..372a3c54e4c5 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -307,10 +307,8 @@ def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod: def _wrapper(self: "Parser", node: doc.AST) -> None: try: return func(self, node) - except DiagnosticError: - raise - except Exception as e: # pylint: disable=broad-except,invalid-name - self.report_error(node, e) + except Exception as err: # pylint: disable=broad-except + self.report_error(node, err) raise return _wrapper @@ -547,6 +545,12 @@ def report_error( err: Union[Exception, str] The error to report. """ + + # If the error is already being raised as a DiagnosticError, + # re-raise it without wrapping it in a DiagnosticContext. + if isinstance(err, DiagnosticError): + raise err + # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] @@ -595,11 +599,8 @@ def visit(self, node: doc.AST) -> None: raise NotImplementedError(f"Visitor of AST node is not implemented: {name}") try: func(node) - except DiagnosticError: - raise - except Exception as e: # pylint: disable=broad-except,invalid-name - self.report_error(node, str(e)) - raise + except Exception as err: # pylint: disable=broad-except + self.report_error(node, err) def visit_body(self, node: List[doc.stmt]) -> Any: """The general body visiting method. diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 08269ddeeb65..011136d5d377 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -104,9 +104,9 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy: try: annotation = self.eval_expr(node) return _normalize_struct_info_proxy(annotation) - except Exception as err: - self.report_error(node, str(err)) - raise err + except Exception as err: # pylint: disable=broad-except + self.report_error(node, err) + raise def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo: @@ -114,9 +114,9 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St try: struct_info = self.eval_expr(node) return _normalize_struct_info(struct_info, var_table) - except Exception as err: + except Exception as err: # pylint: disable=broad-except self.report_error(node, err) - raise err + raise def is_called(node: Any, func_name: str) -> bool: diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 9245ec9c0b2f..8eeb4b3e6fd6 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -127,7 +127,8 @@ void DiagnosticContext::Render() { } if (errs) { - (*this)->renderer = DiagnosticRenderer(); + (*this)->renderer = DiagnosticRenderer([](DiagnosticContext) {}); + // (*this)->diagnostics.clear(); LOG(FATAL) << "DiagnosticError: one or more error diagnostics were " << "emitted, please check diagnostic render for output."; } diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 64f2efd4af9e..fd465f320191 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -179,6 +179,15 @@ def f(x: R.Tensor): return x +def test_incorrect_tensor_shape(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor([16])): + y: R.Tensor(16) = R.add(x, x) + return y + + def test_simple_module(): @I.ir_module class TestModule: @@ -1045,7 +1054,6 @@ def main( def test_call_tir_inplace_with_tuple_var_raises_error(): - with pytest.raises(tvm.error.DiagnosticError): @tvm.script.ir_module @@ -1838,7 +1846,7 @@ def mul_add(x: R.Tensor) -> R.Tensor: _check(InputModule, OutputModule) -def test_context_aware_parsing(): +def test_context_aware_parsing(monkeypatch): @tvm.script.ir_module class Module: @T.prim_func @@ -1863,7 +1871,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32 def _break_env(self, *args): raise RuntimeError("Fail to pass context-aware parsing") - tvm.ir.GlobalVar.__call__ = _break_env + monkeypatch.setattr(tvm.ir.GlobalVar, "__call__", _break_env) _check(Module) diff --git a/tests/python/tvmscript/test_tvmscript_printer_highlight.py b/tests/python/tvmscript/test_tvmscript_printer_highlight.py index 16e90c3563fc..4c33b435f053 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_highlight.py +++ b/tests/python/tvmscript/test_tvmscript_printer_highlight.py @@ -21,7 +21,7 @@ import tvm.testing from tvm import relay from tvm.script import tir as T -from tvm.script.highlight import cprint +from tvm.script.highlight import cprint, _format def test_highlight_script(): @@ -58,12 +58,14 @@ def test_cprint(): # Print nodes with `script` method, e.g. PrimExpr cprint(tvm.tir.Var("v", "int32") + 1) - # Cannot print non-Python-style codes if black installed + # Cannot print non-Python-style codes when using the black + # formatter. This error comes from `_format`, used internally by + # `cprint`, and doesn't occur when using the `ruff` formatter. try: import black with pytest.raises(ValueError): - cprint("if (a == 1) { a +=1; }") + _format("if (a == 1) { a +=1; }", formatter="black") except ImportError: pass From a24204640efe3dcf519ca3388633a8a62a7600eb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 18 Sep 2024 13:01:43 -0500 Subject: [PATCH 153/202] [TVMScript][Relax] Allow return statement in DataflowBlock (#17131) Prior to this commit, TVMScript required the return value of a Relax to be specified outside of any `with R.dataflow()` blocks. This resulted in a common pattern, where the return value of a function was first called with `R.output(ret_value)`, to mark `ret_value` as a `tvm::relax::Var` instead of a `tvm::relax::DataflowVar`, followed immediately by a `return ret_value` statement. This commit updates the TVMScript parser to allow a `return` statement inside a `with R.dataflow()` block. This is syntactic sugar that is equivalent to calling `R.output`, followed by a `return`. With this change, the following two TVMScript examples are now equivalent. (Prior to this change, the `return_inside_dataflow` example would raise an error during parsing.) ```python @R.function(private=True) def output_then_return(A: R.Tensor): with R.dataflow(): B = R.add(A, A) C = R.multiply(B, B) R.output(C) return C @R.function(private=True) def return_inside_dataflow(A: R.Tensor): with R.dataflow(): B = R.add(A, A) C = R.multiply(B, B) return C ``` --- src/script/ir_builder/relax/frame.cc | 69 +++++++++------------ src/script/ir_builder/relax/ir.cc | 23 ++++--- tests/python/relax/test_tvmscript_parser.py | 31 +++++++++ 3 files changed, 75 insertions(+), 48 deletions(-) diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 3153c0770e38..faf6bd6466ad 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -118,36 +118,23 @@ void BlockFrameNode::EnterWithScope() { } } -class DataflowBlockRewriter : public tvm::relax::ExprMutator { +class VarReplacer : public tvm::relax::ExprMutator { public: - static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block, - const Array& output_vars) { - DataflowBlockRewriter rewriter(output_vars); - return Downcast(rewriter.VisitBindingBlock(block)); + explicit VarReplacer( + std::unordered_map + var_remap) { + var_remap_ = std::move(var_remap); } - private: - explicit DataflowBlockRewriter(const Array& output_vars) { - for (const tvm::relax::Var& var : output_vars) { - output_var_set_.insert(var.get()); - } - } - - tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final { - auto it = output_var_set_.find(op); - if (it != output_var_set_.end()) { - // Rewrite dataflow vars to global vars - auto n = make_object(*op); - tvm::relax::Var new_var(n); - this->var_remap_[op->vid] = new_var; - return new_var; + tvm::relax::Var VisitVarDef(const tvm::relax::Var& var) override { + // ExprMutator only applies var_remap_ at usage sites. This + // applies var_remap_ at each definition site as well. + if (auto it = var_remap_.find(var->vid); it != var_remap_.end()) { + return it->second; } else { - return GetRef(op); + return var; } } - - private: - std::unordered_set output_var_set_; }; void BlockFrameNode::ExitWithScope() { @@ -164,25 +151,27 @@ void BlockFrameNode::ExitWithScope() { // Step 3. Rewrite the dataflow block. if (is_dataflow) { - // Step 3.1. Rewrite block binding - block = DataflowBlockRewriter::Rewrite(Downcast(block), output_vars); - - // Step 3.2. Collect global vars' reference in bindings - Map new_global_vars; - for (const tvm::relax::Binding& binding : block->bindings) { - if (!binding->var->IsInstance()) { - new_global_vars.Set(binding->var->vid, binding->var); - } + // Step 3.0. Define a map to replace variables + Array new_output_vars; + std::unordered_map var_remap; + for (const auto& output_var : output_vars) { + tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var)); + new_output_vars.push_back(new_output_var); + var_remap[output_var->vid] = new_output_var; } + VarReplacer mutator(std::move(var_remap)); + + // Step 3.1. Rewrite block binding + block = mutator.VisitBindingBlock(block); // Step 3.3. Rewrite output vars - Array new_output_vars; - for (const auto& var : output_vars) { - auto it = new_global_vars.find(var->vid); - ICHECK(it != new_global_vars.end()); - new_output_vars.push_back((*it).second); - } output_vars = std::move(new_output_vars); + + // Step 3.4 Rewrite usage of output var, if any + auto function = FindFunctionFrame("R.dataflow()"); + if (function->output.defined()) { + function->output = mutator.VisitExpr(function->output.value()); + } } // Step 3. Get the last frame from the IRBuilder frame stack. @@ -196,8 +185,6 @@ void BlockFrameNode::ExitWithScope() { // Step 5. Push the block frame into the corresponding field of the last frame. if (const auto* seq_frame = last_frame.as()) { - ICHECK(!seq_frame->output.defined()) - << "The function is not expected to have output values when emitting blocks."; auto frame = GetRef(seq_frame); frame->binding_blocks.push_back(block); } else { diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 453c7fdb5522..b2e75d0c3698 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -117,20 +117,29 @@ void FuncRetValue(const tvm::relax::Expr& value) { const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); tvm::relax::Expr normalized_value = block_builder->Normalize(value); + IRBuilder ir_builder = IRBuilder::Current(); + // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of // a function body. Therefore if there is any unended block frame when dealing with function // return, we should end the block frame. - Optional block_frame = IRBuilder::Current()->GetLastFrame(); - if (block_frame.defined()) { - block_frame.value()->ExitWithScope(); - ICHECK(!IRBuilder::Current()->FindFrame()) - << "ValueError: Relax functions don't support return in true/false branch of If Node."; + + if (auto opt = ir_builder->GetLastFrame()) { + auto block_frame = opt.value(); + for (const auto& var : tvm::relax::FreeVars(normalized_value)) { + if (var->IsInstance()) { + block_frame->output_vars.push_back(var); + } + } } // Step 2. Add the output value to the function frame. FunctionFrame frame = FindFunctionFrame("return"); CHECK(!frame->output.defined()) - << "ValueError: Relax functions don't support multiple return statement. Please make sure " - "the return statement appears at the end of function."; + << "ValueError: " + << "Relax functions do not support multiple return statement. " + << "However, return of " << normalized_value << " occurred after a return of " + << frame->output << ". " + << "Please make sure function only has a single return statement, " + << "which appears at the end of function."; frame->output = std::move(normalized_value); } diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index fd465f320191..fa62d1484893 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -2410,5 +2410,36 @@ def inferred_sinfo( tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) +def test_return_from_dataflow_block(): + """Return statements imply + + The `R.output` statement in a `R.dataflow()` block marks a + variable that should be a `relax.Var` instead of a + `relax.DataflowVar`, allowing it to be used outside of the + `DataflowBlock` that defined it. A relax function's output is not + part of any binding, and must not contain any `DataflowVar`, so + these are exposed implicitly. + + """ + + @R.function(private=True) + def output_then_return(A: R.Tensor([16], "float16")): + with R.dataflow(): + B = R.add(A, A) + C = R.multiply(B, B) + R.output(C) + + return C + + @R.function(private=True) + def return_inside_dataflow(A: R.Tensor([16], "float16")): + with R.dataflow(): + B = R.add(A, A) + C = R.multiply(B, B) + return C + + tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow) + + if __name__ == "__main__": tvm.testing.main() From 36e3c121b7dcfae3d5d5098186a7ca96e7ff27fc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 19 Sep 2024 12:28:25 -0500 Subject: [PATCH 154/202] [Relax] Validate StructInfo annotations in well-formed check (#17331) * [Relax] Validate StructInfo annotations in well-formed check Prior to this commit, the Relax well-formed checker verified that each expression had a non-null `StructInfo` annotation, but did not perform any validation on the contents of the `StructInfo` annotation. This commit updates the Relax well-formed check to verify that the `StructInfo` annotations are accurate by comparing against the `StructInfo` that would be inferred for an expression. (This only requires that the information is accurate, not that it is complete. For example, an expression that is inferred to be `R.Tensor(shape=[128,8], dtype="float32")` may have annotation of `R.Tensor(ndim=2, dtype="float32"`, but may not have an annotation of `R.Tensor(shape=[128,8], dtype="int32")`.) * lint fix * lint fix --- src/relax/analysis/well_formed.cc | 43 ++++++++++ src/relax/op/op.cc | 21 +++-- .../python/relax/test_analysis_well_formed.py | 85 +++++++++++++++++++ tests/python/relax/test_ast_printer.py | 4 +- tests/python/relax/test_frontend_from_fx.py | 10 +-- .../relax/test_transform_decompose_ops.py | 4 +- .../test_transform_ipc_allreduce_rewrite.py | 4 +- .../relax/test_transform_legalize_ops_ccl.py | 4 +- ..._transform_legalize_ops_create_datatype.py | 34 ++++---- ...sform_legalize_ops_index_linear_algebra.py | 2 +- .../test_transform_legalize_ops_manipulate.py | 51 ++++++----- .../relax/test_transform_legalize_ops_nn.py | 38 ++++++--- ...ansform_legalize_ops_search_statistical.py | 4 +- .../relax/test_transform_realize_vdevice.py | 16 ++-- ...test_transform_static_plan_block_memory.py | 8 +- .../test_transform_to_mixed_precision.py | 12 +-- tests/python/relax/test_tvmscript_parser.py | 10 +-- tests/python/relax/test_vm_cuda_graph.py | 8 +- tests/python/relax/test_vm_multi_device.py | 14 +-- 19 files changed, 268 insertions(+), 104 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 7688c4a64291..7873d5ce2022 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -362,6 +362,49 @@ class WellFormedChecker : public relax::ExprVisitor, << err.what()); } } + + if (check_struct_info_ && call->struct_info_.defined()) { + // The `InferStructInfo` method isn't currently exposed by the + // Normalizer, and can only be called indirectly by normalizing + // an expression that does not yet have `StructInfo`. + auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); + Call copied(call->op, call->args, call->attrs, call->sinfo_args); + Optional normalized = NullOpt; + try { + normalized = dummy_builder->Normalize(copied); + } catch (std::exception& err) { + Malformed(Diagnostic::Error(call) + << "Each Relax expression must be able to have its StructInfo inferred. " + << "However, inferring the struct info of expression " << GetRef(call) + << " resulted in the error: \n" + << err.what()); + } + if (normalized.defined()) { + auto inferred_struct_info = GetStructInfo(normalized.value()); + auto current_struct_info = Downcast(call->struct_info_); + + // An error should be raised if the annotated StructInfo is + // provably incorrect. This check is done using + // `StructInfoBaseCheck(...) < kFailL1`, because `kFailL1` + // represents cases that are neither provably correct nor + // provably incorrect. If this check were replaced with + // `!IsBaseOf(...)`, cases that are correct but not provably + // so would raise an exception. + // + // For example, if a dynamic size in the inferred StructInfo + // is equivalent to the expression used in the annotated + // StructInfo, but the TIR simplifications are not sufficient + // to prove that the two expressions are equivalent, we should + // not raise an error. + if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) < + BaseCheckResult::kFailL1) { + Malformed(Diagnostic::Error(call) + << "All information in StructInfo annotations must be correct. " + << "However, while the expression " << GetRef(call) << " is annotated as " + << current_struct_info << ", the expression outputs " << inferred_struct_info); + } + } + } } void VisitExpr_(const IfNode* op) final { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 3e0f0eba313a..a7d97a59a100 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1021,14 +1021,19 @@ StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& c ICHECK(call->args.size() == 1); ICHECK(call->args[0]->struct_info_.defined()); const auto* tsinfo = GetStructInfoAs(call->args[0]); - ICHECK(tsinfo && tsinfo->shape.defined()); - ShapeExpr shape_expr = Downcast(tsinfo->shape.value()); - ICHECK(shape_expr->values.size() == 1) << "relax.tensor_to_shape expected argument to be 1-d, " - << "but " << call << " has argument " << call->args[0] - << " with struct info " << call->args[0]->struct_info_; - const IntImmNode* ndim = shape_expr->values[0].as(); - ICHECK(ndim); - return ShapeStructInfo(ndim->value); + ICHECK(tsinfo); + ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be 1-d, " + << "but " << call << " has argument " << call->args[0] + << " with struct info " << call->args[0]->struct_info_; + + if (tsinfo->shape.defined()) { + ShapeExpr shape_expr = Downcast(tsinfo->shape.value()); + const IntImmNode* ndim = shape_expr->values[0].as(); + if (ndim) { + return ShapeStructInfo(ndim->value); + } + } + return ShapeStructInfo(kUnknownNDim); } RELAY_REGISTER_OP("relax.tensor_to_shape") diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 3db3efee1afc..d9eefcfd0ef2 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -1295,5 +1295,90 @@ def test_var_binding_with_incomplete_struct_info_must_be_consistent(): assert not rx.analysis.well_formed(main) +def test_incomplete_struct_info_must_be_consistent(): + """StructInfo annotations must be accurate + + Even though StructInfo annotation may be less specific, the + information that they do contain must be correct. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + B: R.Tensor(shape=[128, 32], dtype="float32"), + ): + C: R.Tensor(ndim=3) = R.add(A, B) + return C + + assert not rx.analysis.well_formed(Module) + + +def test_struct_info_annotations_must_be_correct(): + """StructInfo annotations must be correct + + To be well-formed, the inferred struct info must not conflict with + the StructInfo annotations. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + B: R.Tensor(shape=[128, 32], dtype="float32"), + ): + C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B) + return C + + assert not rx.analysis.well_formed(Module) + + +def test_struct_info_may_be_incomplete(): + """StructInfo annotations may be less specific + + The StructInfo annotations are not required to be an exact match + to the inferred StructInfo, and may provide less specific + information than the inference would provide. + + """ + + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + B: R.Tensor(shape=[128, 32], dtype="float32"), + ): + C: R.Object = R.add(A, B) + return C + + assert rx.analysis.well_formed(Module) + + +def test_incomplete_struct_info_must_be_consistent(): + """StructInfo annotations must be accurate + + Even though StructInfo annotation may be less specific, the + information that they do contain must be correct. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main( + A: R.Tensor(shape=[128, 32], dtype="float32"), + B: R.Tensor(shape=[128, 32], dtype="float32"), + ): + C: R.Tensor(ndim=3) = R.add(A, B) + return C + + assert not rx.analysis.well_formed(Module) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 6005ecb0fa58..1df7dcf36f79 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -366,8 +366,8 @@ def f( ) -> R.Object: m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) - w: R.Tensor = R.multiply(z, z) - q: R.Tensor(ndim=2) = R.add(w, w) + w: R.Tensor(ndim=2) = R.multiply(z, z) + q: R.Tensor = R.add(w, w) t = R.add(w, z) sh: R.Shape = R.shape_of(t) o: R.Object = R.call_packed( diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 78fc7abdf748..191ea4da5e56 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -79,7 +79,7 @@ def main( out_layout="NCW", out_dtype="float32", ) - lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 6, 4), dtype="float32") = lv3 R.output(gv) @@ -171,7 +171,7 @@ def main( out_layout="NCW", out_dtype="float32", ) - lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 6, 6), dtype="float32") = lv3 R.output(gv) @@ -263,7 +263,7 @@ def main( out_layout="NCHW", out_dtype="float32", ) - lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1]) + lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1]) lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3 R.output(gv) @@ -355,7 +355,7 @@ def main( out_layout="NCHW", out_dtype="float32", ) - lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1]) + lv2: R.Tensor((1, 3, 1, 1), dtype="float32") = R.reshape(w2, [1, 3, 1, 1]) lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 3, 16, 16), dtype="float32") = lv3 R.output(gv) @@ -447,7 +447,7 @@ def main( out_layout="NCDHW", out_dtype="float32", ) - lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1]) + lv2: R.Tensor((1, 6, 1, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1, 1]) lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2) gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3 R.output(gv) diff --git a/tests/python/relax/test_transform_decompose_ops.py b/tests/python/relax/test_transform_decompose_ops.py index 4e5bcb82e979..2564913d79ae 100644 --- a/tests/python/relax/test_transform_decompose_ops.py +++ b/tests/python/relax/test_transform_decompose_ops.py @@ -360,14 +360,14 @@ def test_op_tensor_to_shape(): @I.ir_module class Before: @R.function - def main(t: R.Tensor(ndim=1, dtype="int64")): + def main(t: R.Tensor([3], dtype="int64")): gv: R.Shape(ndim=3) = R.tensor_to_shape(t) return gv @I.ir_module class Expected: @R.function - def main(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3): + def main(t: R.Tensor([3], dtype="int64")) -> R.Shape(ndim=3): x = T.int64() x_1 = T.int64() x_2 = T.int64() diff --git a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py index da85423aafd7..fa68c16e691d 100644 --- a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py +++ b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py @@ -83,7 +83,7 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global") ) - lv1: R.Tensor((m, n), dtype="float16") = R.reshape(alloc, (m * n,)) # type: ignore + lv1: R.Tensor((m * n,), dtype="float16") = R.reshape(alloc, (m * n,)) # type: ignore alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m * n]), R.dtype("float16"), R.prim_value(0), R.str("global") ) @@ -103,7 +103,7 @@ def main( alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("ipc_memory") ) - lv1: R.Tensor((m, n), dtype="float16") = R.reshape( # type: ignore + lv1: R.Tensor((m * n,), dtype="float16") = R.reshape( # type: ignore alloc, R.shape([m * n]) ) alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor( # type: ignore diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py index 9ea4d21d610d..923a8e8d9739 100644 --- a/tests/python/relax/test_transform_legalize_ops_ccl.py +++ b/tests/python/relax/test_transform_legalize_ops_ccl.py @@ -101,8 +101,8 @@ def test_scatter_from_worker0(): @tvm.script.ir_module class ScatterFromWorker0: @R.function - def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((5, 10), "float32"): - gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, num_workers=2, axis=1) + def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10,5), "float32"): + gv0: R.Tensor((10,5), "float32") = R.ccl.scatter_from_worker0(x, num_workers=2, axis=1) return gv0 @I.ir_module diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py index 7b2b2d2e7644..a8af295ac3b9 100644 --- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -160,19 +160,19 @@ def test_full_like(): @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.full_like(x, v) + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.full_like(x, v) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float32")) + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="int32")) return gv @T.prim_func(private=True) - def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_full"): @@ -191,26 +191,26 @@ def test_full_like_constant_scalar_fill_value(): @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): - gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, "float32")) + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.full_like(x, R.const(-5, "float32")) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): - gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="float32")) + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="int32")) return gv @T.prim_func(private=True) - def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) - T_full[ax0, ax1] = T.float32(-5) + T_full[ax0, ax1] = T.int32(-5) # fmt: on mod = LegalizeOps()(FullLike) @@ -253,19 +253,19 @@ def test_full_like_symbolic(): @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "int32"): m = T.int64() n = T.int64() - gv: R.Tensor((m, n), "float32") = R.full_like(x, v) + gv: R.Tensor((m, n), "int32") = R.full_like(x, v) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "int32"): m = T.int64() n = T.int64() - gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="float32")) + gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="int32")) return gv @T.prim_func(private=True) @@ -273,13 +273,13 @@ def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle): T.func_attr({"tir.noalias": True}) m = T.int64() n = T.int64() - T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") for i0, i1 in T.grid(m, n): with T.block("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) - T_full[ax0, ax1] = rxplaceholder[()] + T_full[ax0, ax1] = T.int32(rxplaceholder[()]) # fmt: on mod = LegalizeOps()(FullLike) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index d0aaddb1ca52..2f4da5cf0653 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -230,7 +230,7 @@ def test_strided_slice_no_strides(): class StridedSlice: @R.function def main(x: R.Tensor((8, 9, 10, 10), "float32")) : - gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4]) + gv: R.Tensor((7, 9, 10, 2), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4]) return gv @tvm.script.ir_module diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index ba5d4d7d1219..a0ecd3c73dc9 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -691,9 +691,12 @@ def test_data_dependent_reshape(): @tvm.script.ir_module class DDReshape: @R.function - def main(x: R.Tensor((3, ), dtype="int64")): - lv: R.Shape([3,]) = R.tensor_to_shape(x) - gv = R.reshape(x, lv) + def main( + x: R.Tensor([2], dtype="int64"), + y: R.Tensor([16],dtype='float32'), + ): + lv: R.Shape(ndim=2) = R.tensor_to_shape(x) + gv = R.reshape(y, lv) return gv # fmt: on @@ -704,29 +707,35 @@ def main(x: R.Tensor((3, ), dtype="int64")): # fmt: off @I.ir_module class Expected: + @R.function + def main( + x: R.Tensor([2], dtype="int64"), + y: R.Tensor([16],dtype="float32"), + ) -> R.Tensor(ndim=2, dtype="float32"): + M = T.int64() + N = T.int64() + gv = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape(ndim=2),)) + _ = R.match_cast(gv, R.Shape([M,N])) + _ = R.shape([M,N]) + gv_1 = R.call_tir(Expected.reshape, (y,), out_sinfo=R.Tensor([M,N], dtype="float32")) + return gv_1 + @T.prim_func(private=True) def reshape( - rxplaceholder: T.Buffer((T.int64(3),), "int64"), var_T_reshape: T.handle + rxplaceholder: T.Buffer(T.int64(16), "float32"), + var_T_reshape: T.handle, ): T.func_attr({"tir.noalias": True}) - x = T.int64() - T_reshape = T.match_buffer(var_T_reshape, (x,), "int64") - # with T.block("root"): - for ax0 in range(x): + M = T.int64() + N = T.int64() + T_reshape = T.match_buffer(var_T_reshape, [M,N], "float32") + for i,j in T.grid(M,N): with T.block("T_reshape"): - v_ax0 = T.axis.spatial(x, ax0) - T.reads(rxplaceholder[v_ax0 % T.int64(3)]) - T.writes(T_reshape[v_ax0]) - T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)] + vi,vj = T.axis.remap('SS',[i,j]) + T.reads(rxplaceholder[(vi*N + vj) % 16]) + T.writes(T_reshape[vi,vj]) + T_reshape[vi,vj] = rxplaceholder[(vi*N + vj) % 16] - @R.function - def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(ndim=1, dtype="int64"): - x_1 = T.int64() - gv: R.Shape([3]) = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),)) - y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1])) - lv: R.Shape([x_1]) = R.shape([x_1]) - gv_1 = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((x_1,), dtype="int64")) - return gv_1 # fmt: on tvm.ir.assert_structural_equal(out_mod, Expected) @@ -914,7 +923,7 @@ def test_squeeze_no_axis(): class Squeeze: @R.function def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) : - gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x) + gv: R.Tensor((2, 3, 4), "float32") = R.squeeze(x) return gv @tvm.script.ir_module diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 92d139d23b5d..d03d48968d90 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -33,7 +33,7 @@ def test_conv1d(): class Conv1d: @R.function def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16, 3), "float32")) -> R.Tensor((2, 64, 13), "float32"): - gv: R.Tensor((2, 4, 13), "float32") = R.nn.conv1d(x, w, strides=(2,), padding=(1,), dilation=(2,), groups=8) + gv: R.Tensor((2, 64, 13), "float32") = R.nn.conv1d(x, w, strides=(2,), padding=(1,), dilation=(2,), groups=8) return gv @tvm.script.ir_module @@ -210,7 +210,7 @@ def test_conv2d(): class Conv2d: @R.function def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"): - gv: R.Tensor((2, 4, 13, 13), "float32") = R.nn.conv2d(x, w, strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8) + gv: R.Tensor((2, 64, 13, 13), "float32") = R.nn.conv2d(x, w, strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8) return gv @tvm.script.ir_module @@ -3298,20 +3298,32 @@ def test_nll_loss(): @tvm.script.ir_module class NLLLoss: @R.function - def main(predictions: R.Tensor((2, 3, 4, 5), "float32"), targets: R.Tensor((2, 4, 5), "int64"), weights: R.Tensor((4,), "float32")) -> R.Tensor((), "float32"): - gv: R.Tensor((), "float32") = R.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-1) + def main( + predictions: R.Tensor((2, 3, 4, 5), "float32"), + targets: R.Tensor((2, 4, 5), "int64"), + weights: R.Tensor((3,), "float32"), + ) -> R.Tensor((), "float32"): + gv = R.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-1) return gv @tvm.script.ir_module class Expected: @R.function - def main(predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), targets: R.Tensor((2, 4, 5), dtype="int64"), weights: R.Tensor((4,), dtype="float32"),) -> R.Tensor((), dtype="float32"): - # block 0 + def main( + predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), + targets: R.Tensor((2, 4, 5), dtype="int64"), + weights: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((), dtype="float32"): gv = R.call_tir(Expected.nll_loss, (predictions, targets, weights), R.Tensor((), dtype="float32")) return gv @T.prim_func(private=True) - def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), rxplaceholder_2: T.Buffer(T.int64(4), "float32"), T_divide: T.Buffer((), "float32"),): + def nll_loss( + predictions: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), + targets: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), + weights: T.Buffer(T.int64(3), "float32"), + output: T.Buffer((), "float32"), + ): # function attr dict T.func_attr({"tir.noalias": True}) # body @@ -3323,9 +3335,9 @@ def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): with T.block("nll_loss"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]]) + T.reads(targets[v_ax0, v_ax1, v_ax2], predictions[v_ax0, targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, v_ax2]]) T.writes(nll_loss[v_ax0, v_ax1, v_ax2]) - nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0)) + nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - predictions[v_ax0, targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * weights[targets[v_ax0, v_ax1, v_ax2]], T.float32(0)) for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): with T.block("nll_loss_red"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) @@ -3337,9 +3349,9 @@ def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): with T.block("nll_loss_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]]) + T.reads(targets[v_ax0, v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, v_ax2]]) T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2]) - nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0)) + nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, v_ax1, v_ax2] != T.int64(-1), weights[targets[v_ax0, v_ax1, v_ax2]], T.float32(0)) for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): with T.block("nll_loss_red_1"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) @@ -3351,8 +3363,8 @@ def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 with T.block("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(nll_loss_red[()], nll_loss_red_1[()]) - T.writes(T_divide[()]) - T_divide[()] = nll_loss_red[()] / nll_loss_red_1[()] + T.writes(output[()]) + output[()] = nll_loss_red[()] / nll_loss_red_1[()] # fmt: on mod = LegalizeOps()(NLLLoss) tvm.ir.assert_structural_equal(mod, Expected) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 2a28151dbe7e..f8dab8981552 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -999,8 +999,8 @@ def test_variance_no_keepdims(): @tvm.script.ir_module class Variance: @R.function - def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 3, 4, 1), "float32"): - gv: R.Tensor((1, 3, 4, 1), "float32") = R.variance(x, [0, 3], keepdims=False) + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), "float32"): + gv: R.Tensor((3, 4), "float32") = R.variance(x, [0, 3], keepdims=False) return gv @I.ir_module diff --git a/tests/python/relax/test_transform_realize_vdevice.py b/tests/python/relax/test_transform_realize_vdevice.py index 4c530d5e4931..fa642821842d 100644 --- a/tests/python/relax/test_transform_realize_vdevice.py +++ b/tests/python/relax/test_transform_realize_vdevice.py @@ -61,8 +61,9 @@ def foo( y1 = y x2 = x1 y2 = y1 - lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) - gv: R.Tensor((2, 3), "float32", "llvm") = R.multiply(lv0, z) + x2 = R.hint_on_device(x2, tvm.cpu()) + lv0 = R.add(x2, y2) + gv = R.multiply(lv0, z) R.output(gv) return gv @@ -91,6 +92,7 @@ def foo( y1: R.Tensor((2, 3), "float32", "llvm") = y x2: R.Tensor((2, 3), "float32", "llvm") = x1 y2: R.Tensor((2, 3), "float32", "llvm") = y1 + x2: R.Tensor((2, 3), "float32", "llvm") = x2 lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) gv: R.Tensor((2, 3), "float32", "llvm") = R.multiply(lv0, z) R.output(gv) @@ -121,7 +123,8 @@ def foo( y1 = y x2 = x1 y2 = y1 - s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) + x2 = R.hint_on_device(x2, tvm.cpu()) + s = R.add(x2, y2) m = R.multiply(s, z) return m @@ -146,6 +149,7 @@ def foo( y1: R.Tensor((2, 3), "float32", "llvm") = y x2: R.Tensor((2, 3), "float32", "llvm") = x1 y2: R.Tensor((2, 3), "float32", "llvm") = y1 + x2: R.Tensor((2, 3), "float32", "llvm") = x2 s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2) m: R.Tensor((2, 3), "float32", "llvm") = R.multiply(s, z) return m @@ -275,10 +279,11 @@ def foo( z: R.Tensor((2, 3), "float32"), ) -> R.Tensor((2, 3), "float32", "cuda"): with R.dataflow(): - lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x, y) + lv0 = R.add(x, y) + lv0 = R.hint_on_device(lv0, tvm.cpu()) lv1 = R.to_vdevice(lv0, "cuda") lv2 = R.add(z, z) - gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv1, lv2) + gv = R.multiply(lv1, lv2) R.output(gv) return gv @@ -304,6 +309,7 @@ def foo( ) -> R.Tensor((2, 3), "float32", "cuda"): with R.dataflow(): lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x, y) + lv0: R.Tensor((2, 3), "float32", "llvm") = lv0 lv1: R.Tensor((2, 3), "float32", "cuda") = R.to_vdevice(lv0, "cuda") lv2: R.Tensor((2, 3), "float32", "cuda") = R.add(z, z) gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv1, lv2) diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index f9e632d34897..1150827b19f9 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -1386,11 +1386,11 @@ def main( ) cls.cumsum(probs, lv1, alloc1) cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = alloc1 - lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.call_packed( + lv1_1: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_packed( "vm.builtin.reshape", cumsum, R.shape([batch_size, vocab_size]), - sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float"),), + sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),), ) return lv1_1 @@ -1403,7 +1403,7 @@ def cumsum(var_A: T.handle, var_A_1: T.handle, var_exclusive_scan_thrust: T.hand @R.function def main( probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32") - ) -> R.Tensor(("batch_size", "vocab_size"), dtype="int32"): + ) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"): batch_size = T.int64() vocab_size = T.int64() R.func_attr( @@ -1437,7 +1437,7 @@ def main( ) cls.cumsum(probs, lv1, alloc1) cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = alloc1 - lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.call_packed( + lv1_1: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_packed( "vm.builtin.reshape", cumsum, R.shape([batch_size, vocab_size]), diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py index ed10fc95c723..658f80a06ec5 100644 --- a/tests/python/relax/test_transform_to_mixed_precision.py +++ b/tests/python/relax/test_transform_to_mixed_precision.py @@ -906,15 +906,15 @@ def main( ) -> R.Tensor((1, 512, 64, 64), dtype="float32"): # block 0 with R.dataflow(): - lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d( + lv142: R.Tensor((1, 512, 62, 62), dtype="float32") = R.nn.conv2d( x, w, strides=[1, 1], padding=[0, 0, 0, 0], out_dtype="float32", ) - lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1)) - lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = R.add(lv142, lv143) + lv143: R.Tensor((1, 512, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1)) + lv144: R.Tensor((1, 512, 62, 62), dtype="float32") = R.add(lv142, lv143) R.output(lv144) return lv144 @@ -1001,15 +1001,15 @@ def main( ) -> R.Tensor((1, 512, 64, 64), dtype="float32"): # block 0 with R.dataflow(): - lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d( + lv142: R.Tensor((1, 512, 62, 62), dtype="float32") = R.nn.conv2d( x, w, strides=[1, 1], padding=[0, 0, 0, 0], out_dtype="float32", ) - lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1)) - lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = R.add(lv142, lv143) + lv143: R.Tensor((1, 512, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1)) + lv144: R.Tensor((1, 512, 62, 62), dtype="float32") = R.add(lv142, lv143) R.output(lv144) return lv144 diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index fa62d1484893..3e64c928ae61 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -882,8 +882,8 @@ def foo( ) -> R.Object: m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) - w: R.Tensor = R.multiply(z, z) - q: R.Tensor(ndim=2) = R.add(w, w) + w: R.Tensor(ndim=2) = R.multiply(z, z) + q: R.Tensor = R.add(w, w) t = R.add(w, z) sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh) @@ -902,9 +902,9 @@ def _check_struct_info(binding, expected_sinfo): sh = bindings[4].var _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) - _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1)) - _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2)) - _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=2)) + _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=2)) _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) _check_struct_info(bindings[5], relax.TensorStructInfo(sh)) _check_struct_info(bindings[6], relax.ObjectStructInfo()) diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index 49ebcc1d05b2..b6c8cdfdeea4 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -36,13 +36,13 @@ def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl R.func_attr({"global_symbol": "main"}) gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),)) storage: R.Object = gv[0] - alloc: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) _: R.Tuple = cls.add(x, alloc) storage1: R.Object = gv[1] gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, storage1, storage) gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, gv1, R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),)) storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("uint8")) - alloc3: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + alloc3 = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) lv4: R.Tensor((16, 16), dtype="float32") = gv2[0] _3: R.Tuple = cls.add(lv4, alloc3) lv5: R.Tensor(dtype="float32") = alloc3 @@ -71,12 +71,12 @@ def cuda_graph_capture(alloc: R.Tensor((16, 16), dtype="float32"), storage1: R.O cls = Module R.func_attr({"global_symbol": "cuda_graph_capture"}) lv0: R.Tensor((16, 16), dtype="float32") = alloc - alloc1: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + alloc1 = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) _1: R.Tuple = cls.add(lv0, alloc1) lv1: R.Tensor(dtype="float32") = alloc1 lv2: R.Tuple(R.Tensor(dtype="float32")) = (lv1,) lv3: R.Tensor(dtype="float32") = lv2[0] - alloc2: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) + alloc2 = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) _2: R.Tuple = cls.add(lv3, alloc2) lv4: R.Tensor(dtype="float32") = alloc2 gv: R.Tuple(R.Tensor(dtype="float32")) = (lv4,) diff --git a/tests/python/relax/test_vm_multi_device.py b/tests/python/relax/test_vm_multi_device.py index ec2fbd1cdf60..73c78d70f042 100644 --- a/tests/python/relax/test_vm_multi_device.py +++ b/tests/python/relax/test_vm_multi_device.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Test eliminate common subexpr pass""" + from typing import List import tvm from tvm import relax @@ -61,11 +62,12 @@ def foo( z: R.Tensor((4, 5), "float32"), ) -> R.Tensor((2, 5), "float32"): with R.dataflow(): - lv0: R.Tensor((2, 4), "float32", "llvm:0") = R.matmul(x, y) # noqa: F722 + lv0 = R.matmul(x, y) + lv0 = R.hint_on_device(lv0, tvm.cpu(0)) lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice( # noqa: F722 - lv0, "llvm:1" # noqa: F722 + lv0, "llvm:1" ) - gv = R.matmul(lv1, z) # noqa: F722 + gv = R.matmul(lv1, z) R.output(gv) return gv @@ -109,11 +111,13 @@ def foo( with R.dataflow(): lv0: R.Tensor((2, 4), "float32", "cuda:0") = R.matmul(a, b) # noqa: F722 lv1: R.Tensor((2, 4), "float32", "cuda:1") = R.to_vdevice( # noqa: F722 - lv0, "cuda:1" # noqa: F722 + lv0, + "cuda:1", # noqa: F722 ) lv2: R.Tensor((2, 5), "float32", "cuda:1") = R.matmul(lv1, c) # noqa: F722 lv3: R.Tensor((2, 5), "float32", "cuda:2") = R.to_vdevice( # noqa: F722 - lv2, "cuda:2" # noqa: F722 + lv2, + "cuda:2", # noqa: F722 ) gv: R.Tensor((2, 6), "float32", "cuda:2") = R.matmul(lv3, d) # noqa: F722 R.output(gv) From 660fd1e47e32fc1a7614774601d1c2b8f746ac88 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 19 Sep 2024 10:29:34 -0700 Subject: [PATCH 155/202] [DOCS] More clarity on security model of RPC server (#17382) This PR updates the python docstrings to include more clarity on RPC server security model. --- python/tvm/rpc/__init__.py | 5 +++++ python/tvm/rpc/server.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/python/tvm/rpc/__init__.py b/python/tvm/rpc/__init__.py index b64ba33d9e09..91e042b55fa1 100644 --- a/python/tvm/rpc/__init__.py +++ b/python/tvm/rpc/__init__.py @@ -23,6 +23,11 @@ The test program compiles the program on local server, upload and run remote RPC server, get the result back to verify correctness. + +TVM RPC server assumes that the user is trusted and needs to be +used in a trusted network environment and encrypted channels. +It allows writings of arbitrary files into the server and provide +full remote code execution capabilities to anyone who can access this API. """ from .server import Server diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 7c1a19856211..63c0a92ab8e1 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -474,6 +474,11 @@ class Server(object): Note ---- + TVM RPC server assumes that the user is trusted and needs to be + used in a trusted network environment and encrypted channels. + It allows writings of arbitrary files into the server and provide + full remote code execution capabilities to anyone who can access this API. + The RPC server only sees functions in the tvm namespace. To bring additional custom functions to the server env, you can use server_init_callback. From 85f2cc318595b4e5f005509fbd5acf0b34c21423 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 20 Sep 2024 09:23:30 +0900 Subject: [PATCH 156/202] [Relax][PyTorch] Fix output shape of `torch.nn.functional.scaled_dot_product_attention` (#17379) * fix the testcase * transpose the output * fix msc testcase --- .../tvm/contrib/msc/core/transform/pattern.py | 12 +++++++---- .../tvm/relax/frontend/torch/fx_translator.py | 4 +++- src/contrib/msc/framework/tvm/relax_opcode.cc | 1 + .../contrib/test_msc/test_graph_build.py | 9 ++------ tests/python/relax/test_frontend_from_fx.py | 21 +++++++++++++------ 5 files changed, 29 insertions(+), 18 deletions(-) diff --git a/python/tvm/contrib/msc/core/transform/pattern.py b/python/tvm/contrib/msc/core/transform/pattern.py index fdc6a628310d..135bac64ae80 100644 --- a/python/tvm/contrib/msc/core/transform/pattern.py +++ b/python/tvm/contrib/msc/core/transform/pattern.py @@ -330,7 +330,8 @@ def make_relax_attention_pattern() -> ( q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q) k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k) v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v) - out = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans) + attention = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans) + out = relax_pattern.is_op("relax.permute_dims")(attention) annotations = { "weight_q": weight_q, "weight_k": weight_k, @@ -338,7 +339,8 @@ def make_relax_attention_pattern() -> ( "q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans, - "attention": out, + "attention": attention, + "out": out, } return out, annotations @@ -378,7 +380,8 @@ def make_relax_mask_attention_pattern() -> ( q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q) k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k) v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v) - out = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask) + attention = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask) + out = relax_pattern.is_op("relax.permute_dims")(attention) annotations = { "weight_q": weight_q, "weight_k": weight_k, @@ -387,7 +390,8 @@ def make_relax_mask_attention_pattern() -> ( "q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans, - "attention": out, + "attention": attention, + "out": out, } return out, annotations diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 983bce0255d9..27da69dbb182 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1015,7 +1015,9 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: assert "float" in attn_mask.struct_info.dtype, msg return self.block_builder.emit( - relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) + transpose_S_H( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) + ) ) def _unbind(self, node: fx.Node) -> relax.Var: diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 1913e8ecda8e..73722f987701 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -107,6 +107,7 @@ class RelaxAttentionCodeGen : public RelaxOpCode { .op_list_arg(axes_key, "axes"); } stack_.op_call().op_inputs_arg(false).op_arg("scale").op_str_arg("causal_mask"); + stack_.op_call("relax.op.permute_dims").op_output_arg().op_list_arg("axes_3", "axes"); } }; diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index 60c8a73dcc67..7fa71df20b45 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -2362,12 +2362,7 @@ def forward(self, q_data, k_data, v_data): {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, ], "outputs": [ - { - "name": "attention", - "shape": [1, seq, 8, 64], - "dtype": "float32", - "layout": "ABCD", - } + {"name": "attention", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 4, "input": 3, "msc.attention": 1}, } @@ -2396,7 +2391,7 @@ def forward(self, q_data, k_data, v_data, mask): "outputs": [ { "name": "attention_bias", - "shape": [1, seq, 8, 64], + "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ABCD", } diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 191ea4da5e56..2cabcba325b2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3825,7 +3825,7 @@ def main( inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), - ) -> R.Tensor((32, 128, 8, 64), dtype="float32"): + ) -> R.Tensor((32, 8, 128, 64), dtype="float32"): with R.dataflow(): lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( inp_0, axes=[0, 2, 1, 3] @@ -3839,7 +3839,10 @@ def main( lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( lv, lv1, lv2, scale=None ) - gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3 + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] + ) + gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4 R.output(gv) return gv @@ -3851,7 +3854,7 @@ def main( inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"), - ) -> R.Tensor((32, 128, 8, 64), dtype="float32"): + ) -> R.Tensor((32, 8, 128, 64), dtype="float32"): with R.dataflow(): lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( inp_0, axes=[0, 2, 1, 3] @@ -3865,7 +3868,10 @@ def main( lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( lv, lv1, lv2, inp_3, scale=None ) - gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3 + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] + ) + gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4 R.output(gv) return gv @@ -3876,7 +3882,7 @@ def main( inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), - ) -> R.Tensor((32, 128, 8, 64), dtype="float32"): + ) -> R.Tensor((32, 8, 128, 64), dtype="float32"): with R.dataflow(): lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( inp_0, axes=[0, 2, 1, 3] @@ -3890,7 +3896,10 @@ def main( lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( lv, lv1, lv2, scale=None, causal_mask="TopLeft" ) - gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3 + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] + ) + gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4 R.output(gv) return gv From 931efc72b2a80d3d21c227324217de9ce76256ca Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 21 Sep 2024 08:26:09 -0700 Subject: [PATCH 157/202] [Disco] Enable float8 data type in disco (#17398) This PR enables the float8 data type in disco, except all reduce operation. Since in this PR, we pretend float8 to be uint8. --- src/runtime/disco/nccl/nccl.cc | 6 +++++- src/runtime/disco/nccl/nccl_context.h | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index a5240aa2b2c5..6ee54e14f37b 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -114,8 +114,12 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv ShapeTuple shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); + DataType dtype = DataType(send->dtype); + if (dtype == DataType::NVFloat8E4M3() || dtype == DataType::NVFloat8E5M2()) { + LOG(FATAL) << "Float8 data type cannot be allreduced, as nccl does not support this data type."; + } NCCL_CALL(ncclAllReduce(send->data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(send->dtype)), + /*datatype=*/AsNCCLDataType(dtype), /*op=*/AsNCCLRedOp(reduce_kind), in_group ? ctx->group_comm : ctx->global_comm, stream)); } diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index b874da219fe4..6c1eaf749a67 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -86,7 +86,10 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { if (dtype == DataType::Int(8)) { return ncclInt8; } - if (dtype == DataType::UInt(8)) { + if (dtype == DataType::UInt(8) || dtype == DataType::NVFloat8E4M3() || + dtype == DataType::NVFloat8E5M2()) { + // For float8 data type, pretend to be uint8 in nccl. + // And will throw error when allreduce, as it makes no sense in this case. return ncclUint8; } if (dtype == DataType::Int(32)) { From 425e15b4475b2fdb143d82d14e781c1bd68fb318 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Sun, 22 Sep 2024 14:50:35 +0800 Subject: [PATCH 158/202] [MSC] Support concat with constant inputs (#17394) * add test for concat * add doc --- cmake/modules/contrib/MSC.cmake | 1 + python/tvm/contrib/msc/core/ir/translate.py | 342 ------------------ python/tvm/contrib/msc/pipeline/config.py | 172 --------- src/contrib/msc/core/ir/graph_builder.cc | 29 +- src/contrib/msc/core/ir/graph_builder.h | 14 + src/contrib/msc/core/transform/fuse_tuple.cc | 32 +- .../contrib/test_msc/test_graph_build.py | 51 +++ .../contrib/test_msc/test_translate_relax.py | 254 +++++++------ .../contrib/test_msc/test_translate_relay.py | 22 ++ .../test_msc/test_translate_tensorrt.py | 23 ++ .../contrib/test_msc/test_translate_torch.py | 23 ++ 11 files changed, 327 insertions(+), 636 deletions(-) delete mode 100644 python/tvm/contrib/msc/core/ir/translate.py delete mode 100644 python/tvm/contrib/msc/pipeline/config.py diff --git a/cmake/modules/contrib/MSC.cmake b/cmake/modules/contrib/MSC.cmake index d2dd6fc14fb1..5779ea52175b 100644 --- a/cmake/modules/contrib/MSC.cmake +++ b/cmake/modules/contrib/MSC.cmake @@ -20,6 +20,7 @@ if(USE_MSC) list(APPEND COMPILER_SRCS ${MSC_CORE_SOURCE}) tvm_file_glob(GLOB_RECURSE MSC_RUNTIME_SOURCE "src/runtime/contrib/msc/*.cc") + set_source_files_properties(${MSC_RUNTIME_SOURCE} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") list(APPEND RUNTIME_SRCS ${MSC_RUNTIME_SOURCE}) if(USE_TENSORRT_RUNTIME) diff --git a/python/tvm/contrib/msc/core/ir/translate.py b/python/tvm/contrib/msc/core/ir/translate.py deleted file mode 100644 index b5bfa12b677a..000000000000 --- a/python/tvm/contrib/msc/core/ir/translate.py +++ /dev/null @@ -1,342 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.ir.translate""" - -from typing import Dict, Optional, Tuple, List - -import tvm -from tvm.relax.transform import BindParams -from tvm.relax import PyExprVisitor -from tvm.relax.backend.pattern_registry import get_patterns_with_prefix -from tvm.relay.expr_functor import ExprVisitor -from tvm.relay.build_module import bind_params_by_name -from tvm.relay import dataflow_pattern as relay_pattern -from tvm.contrib.msc.core import transform as msc_transform -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils -from .graph import MSCGraph, MSCTensor - - -def normalize_weights( - t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph -) -> Dict[str, tvm.nd.array]: - """Normalize the weghts. - - Parameters - ---------- - t_weights: dict of - The weights extracted from IRModule. - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - - Returns - ------- - weights: dict of - The normalized weights. - """ - - def _to_data(ref_t, data): - weight_t = graph.find_tensor(ref_t.name) - if weight_t.ndim == 1: - if ref_t.ndim != weight_t.ndim: - return tvm.nd.array(data.asnumpy().reshape(weight_t.get_shape())) - return data - if ref_t.layout and weight_t.layout: - ref_layout, weight_layout = ref_t.layout.name, weight_t.layout.name - if ref_layout != weight_layout: - assert all( - l in ref_layout for l in weight_layout - ), "layout mismatch {} compare to {}".format(ref_t, weight_t) - permute = [ref_layout.index(l) for l in weight_layout] - return tvm.nd.array(data.asnumpy().transpose(*permute)) - return data - - weights = {t.name: _to_data(t, d) for t, d in t_weights.items() if graph.has_tensor(t.name)} - return weights - - -def from_relax( - mod: tvm.IRModule, - params: Optional[Dict[str, tvm.nd.array]] = None, - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, - opt_config: Optional[Dict[str, str]] = None, -) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]: - """Change IRModule to MSCGraph. - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - params: dict of - The parameters of the IRModule. - trans_config: dict - The config for transform IRModule. - build_config: dict - The config for build MSCGraph. - opt_config: dict - The config for optimize the relax before translate. - - Returns - ------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The weights from the IRModule. - """ - - trans_config = trans_config or {} - build_config = build_config or {} - opt_config = opt_config or {} - entry = trans_config.get("entry", "main") - if params: - mod = BindParams("main", params)(mod) - opt_level = opt_config.get("opt_level", 1) - if opt_level > 0: - mod = tvm.transform.Sequential( - [ - tvm.relax.transform.FoldConstant(), - ] - )(mod) - patterns = get_patterns_with_prefix("msc.") - passes = [ - tvm.relax.transform.FuseOpsByPattern( - patterns, bind_constants=False, annotate_codegen=False - ), - msc_transform.SetExprName(entry_name=entry, target=trans_config.get("target", "")), - msc_transform.SetExprLayout( - trans_config.get("allow_layout_missing", True), entry_name=entry - ), - ] - mod = tvm.transform.Sequential(passes)(mod) - graph = _ffi_api.BuildFromRelax(mod, entry, msc_utils.dump_dict(build_config)) - t_weights = _ffi_api.GetRelaxWeights(mod, entry) - return graph, normalize_weights(t_weights, graph) - - -def get_relay_patterns( - mod: tvm.IRModule, - entry_name: str = "main", -) -> List[Tuple[str, relay_pattern.DFPattern, callable]]: - """Filter relay patterns based on mod. - - Parameters - ---------- - mod: IRModule - The IRModule of relay. - entry_name: str - The entry name. - - Returns - ------- - patterns: list - The useful patterns for relay - """ - - class OpExtractor(ExprVisitor): - """Extract ops from expr.""" - - def extract(self, expr): - self._optypes = set() - super().visit(expr) - return self._optypes - - def visit_call(self, expr): - super().visit_call(expr) - if isinstance(expr.op, tvm.ir.Op): - self._optypes.add(expr.op.name) - - op_names = OpExtractor().extract(mod[entry_name]) - skip_tags, patterns = set(), list(tvm.relay.op.contrib.get_pattern_table("msc")) - if "nn.conv1d" not in op_names or "add" not in op_names: - skip_tags.add("msc.conv1d_bias") - if "nn.conv2d" not in op_names or "add" not in op_names: - skip_tags.add("msc.conv2d_bias") - if "nn.batch_matmul" not in op_names or "add" not in op_names: - skip_tags.add("msc.linear_bias") - if "nn.batch_matmul" not in op_names: - skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.linear")) - if "nn.dense" not in op_names: - skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.matmul")) - if "take" not in op_names: - skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.embedding")) - if "erf" not in op_names: - skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.gelu")) - valid_patterns = [p for p in patterns if p[0] not in skip_tags] - return valid_patterns - - -def from_relay( - mod: tvm.IRModule, - params: Optional[Dict[str, tvm.nd.array]] = None, - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, - opt_config: Optional[Dict[str, str]] = None, -) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]: - """Change IRModule to MSCGraph. - - Parameters - ---------- - mod: IRModule - The IRModule of relay. - params: dict of - The parameters of the IRModule. - trans_config: dict - The config for transform IRModule. - build_config: dict - The config for build MSCGraph. - opt_config: dict - The config for optimize the relay before translate. - - Returns - ------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The weights from the IRModule. - """ - - trans_config = trans_config or {} - build_config = build_config or {} - opt_config = opt_config or {} - # TODO(tong.meng): optimize before translate? - opt_level = opt_config.get("opt_level", 0) - if params: - mod["main"] = bind_params_by_name(mod["main"], params) - if opt_level > 0: - target = opt_config.get("target", "llvm") - disabled_pass = opt_config.get("disabled_pass", []) + [ - "SimplifyInference", - "CanonicalizeOps", - "FuseOps", - "AlterOpLayout", - ] - with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): - mod, params = tvm.relay.optimize(mod, target=target, params=params) - patterns = get_relay_patterns(mod) - passes = [ - tvm.relay.transform.InferType(), - tvm.relay.transform.MergeComposite(patterns), - msc_transform.SetExprName(as_relax=False), - ] - mod = tvm.transform.Sequential(passes)(mod) - graph = _ffi_api.BuildFromRelay(mod, "main", msc_utils.dump_dict(build_config)) - t_weights = _ffi_api.GetRelayWeights(mod, "main") - return graph, normalize_weights(t_weights, graph) - - -@tvm.relax.expr_functor.visitor -class BYOCChecker(PyExprVisitor): - """Checker to check if any non-target ops exist""" - - def check(self, func_names, expr): - self._func_names = func_names - self._non_target_exprs = [] - if isinstance(expr, tvm.relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, tvm.relax.BindingBlock): - self.visit_binding_block(expr) - assert len(self._non_target_exprs) == 0, "Some exprs not on target {}".format(expr) - - def visit_var_binding_(self, binding) -> None: - super().visit_var_binding_(binding) - if isinstance(binding.value, tvm.relax.Call): - if isinstance(binding.value.op, tvm.relax.GlobalVar): - if binding.value.op.name_hint not in self._func_names: - self._non_target_exprs.append(binding.value) - else: - self._non_target_exprs.append(binding.value) - elif not isinstance(binding.value, tvm.relax.DataflowVar): - self._non_target_exprs.append(binding.value) - - -def byoc_partition( - target: str, - mod: tvm.IRModule, - params: Optional[Dict[str, tvm.nd.array]] = None, - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, - allow_incomplete: bool = True, -) -> Tuple[tvm.IRModule, List[Tuple[str, MSCGraph, Dict[str, tvm.nd.array]]]]: - """Partition module to target sub functions. - - Parameters - ---------- - target: str - The target for the BYOC. - mod: IRModule - The IRModule of relax. - trans_config: dict - The config for transform IRModule. - params: dict of - The parameters of the IRModule. - build_config: dict - The config for build MSCGraph. - allow_incomplete: bool - Whether allow some ops not on tensorrt - - - Returns - ------- - mod: IRModule - The IRModule of partitioned relax. - graphs_info: list<> - The func list, each element for a sub graph. - """ - - trans_config = trans_config or {} - build_config = build_config or {} - build_config["target"] = target - entry = trans_config.get("entry", "main") - if params: - mod = BindParams("main", params)(mod) - - def _partition_mod(mod, as_msc=True): - patterns = get_patterns_with_prefix(target) - if as_msc: - passes = [tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=False)] - else: - passes = [tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=True)] - passes.extend( - [ - msc_transform.BindShape(), - msc_transform.FuseTuple(target), - tvm.relax.transform.MergeCompositeFunctions(), - msc_transform.SetExprName(target=target), - msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), - ] - ) - return tvm.transform.Sequential(passes)(mod) - - def _is_target_func(func): - if "Codegen" not in func.attrs: - return False - return func.attrs["Codegen"] == target - - msc_mod = _partition_mod(mod) - func_names = [var.name_hint for var, func in msc_mod.functions.items() if _is_target_func(func)] - - if not allow_incomplete: - assert len(func_names) == 1, "More than 1 target func is found: " + str(msc_mod) - BYOCChecker().check(func_names, msc_mod[entry]) - - graphs_info, all_weights = [], _ffi_api.GetRelaxWeights(msc_mod, entry) - for idx, name in enumerate(func_names): - build_config.update({"graph_name": target + "_" + str(idx), "byoc_entry": name}) - graph = _ffi_api.BuildFromRelax(msc_mod, entry, msc_utils.dump_dict(build_config)) - graphs_info.append((name, graph, normalize_weights(all_weights, graph))) - return _partition_mod(mod, False), graphs_info diff --git a/python/tvm/contrib/msc/pipeline/config.py b/python/tvm/contrib/msc/pipeline/config.py deleted file mode 100644 index b6d80fd42089..000000000000 --- a/python/tvm/contrib/msc/pipeline/config.py +++ /dev/null @@ -1,172 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.pipeline.config""" - -from typing import List, Union, Dict, Tuple - -from tvm.contrib.msc.core.tools import ToolType -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core import utils as msc_utils - - -def support_tool(tool: dict, stage: str, run_type: str) -> bool: - """Check if the tool is supported - - Parameters - ---------- - tool: dict - The tool config, - stage: str - The compile stage. - run_type: str - The runtime type. - - Returns - ------- - supported: bool - Whether the tool is supported. - """ - - run_type = tool.get("run_type", run_type) - if stage == MSCStage.BASELINE: - return tool["tool_type"] == ToolType.TRACKER - return True - - -def config_tool(tool_type: str, raw_config: Union[dict, str]) -> dict: - """Config the tool - - Parameters - ---------- - tool_type: str - The tool type, - raw_config: str| dict - The tool config or style. - - Returns - ------- - config: dict - The config for tool. - """ - - if isinstance(raw_config, dict): - if "config_style" in raw_config: - config_style = raw_config.pop("config_style") - else: - config_style = "default" - else: - config_style, raw_config = raw_config, None - configer_cls = msc_utils.get_registered_tool_configer(tool_type, config_style) - assert configer_cls, "Can not find configer for {}:{}".format(tool_type, config_style) - return {"tool_type": tool_type, **configer_cls().config(raw_config)} - - -def create_config( - inputs: List[dict], - outputs: List[str], - model_type: str, - baseline_type: str = None, - optimize_type: str = None, - compile_type: str = None, - dataset: Dict[str, dict] = None, - tools: List[Tuple[str, Union[dict, str]]] = None, - skip_config: Dict[str, str] = None, - **extra_config, -) -> dict: - """Create config for msc pipeline - - Parameters - ---------- - inputs: list - The inputs info, - outputs: list - The output names. - model_type: str - The model type. - baseline_type: str - The baseline type. - compile_type: str - The compile type. - optimize_type: str - The optimize type. - dataset: dict - The datasets for compile pipeline. - tools: list - The tools config. - skip_config: dict - The skip config for compile. - extra_config: dict - The extra config. - """ - - baseline_type = baseline_type or model_type - optimize_type = optimize_type or baseline_type - compile_type = compile_type or optimize_type - tools = tools or [] - tools = [config_tool(t_type, t_config) for t_type, t_config in tools] - # basic config - config = { - "model_type": model_type, - "inputs": inputs, - "outputs": outputs, - "dataset": dataset, - "tools": tools, - MSCStage.PREPARE: {"profile": {"benchmark": {"repeat": -1}}}, - MSCStage.BASELINE: { - "run_type": baseline_type, - "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, - }, - } - - # config optimize - opt_tools = [t for t in tools if support_tool(t, MSCStage.OPTIMIZE, optimize_type)] - if opt_tools: - config[MSCStage.OPTIMIZE] = { - "run_type": optimize_type, - "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, - } - - # config compile - config[MSCStage.COMPILE] = { - "run_type": compile_type, - "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, - } - - # update config - if extra_config: - config = msc_utils.update_dict(config, extra_config) - - # skip stages - skip_config = skip_config or {} - for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: - if stage not in config: - continue - for key in ["all", stage]: - if key not in skip_config: - continue - if skip_config[key] == "stage": - config.pop(stage) - elif skip_config[key] == "profile": - config[stage].pop("profile") - elif skip_config[key] == "check": - config[stage]["profile"].pop("check") - elif skip_config[key] == "benchmark": - config[stage]["profile"].pop("benchmark") - else: - raise TypeError("Unexpected skip type " + str(skip_config[key])) - - return config diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 20c7dbcc9172..abb7dfbd5e02 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -294,6 +294,25 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional layout = layouts_[node_name]; } + // specail case for tuple + if (optype == "tuple" && expr->IsInstance() && + Downcast(expr)->op->IsInstance()) { + const auto& call_node = Downcast(expr); + ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; + const auto& tuple_func = target_funcs_[call_node->op]; + for (size_t i = 0; i < call_node->args.size(); i++) { + expr_tensor_map_.Set(tuple_func->params[i], expr_tensor_map_[call_node->args[i]]); + } + VisitExpr(tuple_func); + ICHECK(expr_tensor_map_.count(tuple_func->body->body)) + << "Can not find seqexpr body " << tuple_func->body->body; + const auto& outputs = expr_tensor_map_[tuple_func->body->body]; + const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr; + expr_tensor_map_.Set(ref_expr, outputs); + ICHECK(tensor_input_map_.count(outputs[0])) << "Can not find tensor " << outputs[0]; + return Downcast(tensor_input_map_[outputs[0]].first); + } + // get plugin const auto& plugin = IsPlugin(optype) ? GetPlugin(optype) : Plugin(); @@ -814,6 +833,14 @@ void RelaxWeightsExtractor::VisitExpr_(const relax::ConstantNode* op) { weights_.Set(weight, op->data); } +void RelaxWeightsExtractor::VisitExpr_(const relax::CallNode* op) { + RelaxExprVisitor::VisitExpr_(op); + if (const auto* v_node = op->op.as()) { + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + VisitExpr(func); + } +} + void RelayFuncAttrGetter::VisitExpr_(const relay::CallNode* op) { RelayExprVisitor::VisitExpr_(op); if (op->attrs.defined()) { @@ -1163,7 +1190,7 @@ TVM_REGISTER_GLOBAL("msc.core.GetRelaxWeights") .set_body_typed([](const IRModule& relax_module, const String& entry_name) -> Map { const auto& func = Downcast(relax_module->Lookup(entry_name)); - return RelaxWeightsExtractor().GetWeights(func); + return RelaxWeightsExtractor(relax_module).GetWeights(func); }); TVM_REGISTER_GLOBAL("msc.core.BuildFromRelay") diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 250fa38ef91b..269a8a213ce8 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -325,13 +325,27 @@ class RelaxGraphBuilder : public RelaxExprVisitor { class RelaxWeightsExtractor : public RelaxExprVisitor { public: + /*! + * \brief The constructor of RelaxGraphBuilder + * \param ref_module the reference module. + * \param name the name of the graph. + * \param options the options of build the graph. + */ + explicit RelaxWeightsExtractor(const IRModule& ref_module) : RelaxExprVisitor() { + ref_module_ = ref_module; + } + /*! \brief Visit the constant and save weights */ Map GetWeights(const relax::Function& func); void VisitExpr_(const relax::ConstantNode* op) final; + void VisitExpr_(const relax::CallNode* op) final; + private: Map weights_; + Map local_funcs_; + IRModule ref_module_; }; class RelayFuncAttrGetter : public RelayExprVisitor { diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index be1a10718c98..6c82c589c82a 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -70,9 +70,20 @@ class TupleFuser : public ExprMutator { bool has_tuple_arg = false; if (target_funcs_.count(val->op)) { Array new_args; - for (const auto& arg : val->args) { + for (size_t i = 0; i < val->args.size(); i++) { + const auto& arg = val->args[i]; if (arg->IsInstance()) { - const auto& func_call = AddFunc(arg); + String tuple_name; + const auto& name_opt = + target_funcs_[val->op]->GetAttr(msc_attr::kUnique); + if (name_opt.defined()) { + if (val->args.size() == 1) { + tuple_name = name_opt.value() + "_input"; + } else { + tuple_name = name_opt.value() + "_inputs." + std::to_string(i); + } + } + const auto& func_call = AddFunc(arg, tuple_name); const auto& tuple_out = builder_->Emit(func_call); ICHECK(target_funcs_.count(func_call->op)) << "Can not find target func " << func_call->op; @@ -118,7 +129,7 @@ class TupleFuser : public ExprMutator { } private: - Call AddFunc(const Expr& expr) { + Call AddFunc(const Expr& expr, const String tuple_name = "") { builder_->BeginDataflowBlock(); Array inputs; if (const auto* v_node = expr.as()) { @@ -133,6 +144,10 @@ class TupleFuser : public ExprMutator { Array params; Map added_params; for (size_t i = 0; i < inputs.size(); i++) { + if (inputs[i]->IsInstance()) { + func_inputs.push_back(inputs[i]); + continue; + } if (!added_params.count(inputs[i])) { const auto& name = String("param_" + std::to_string(i)); const auto& var = Var(std::move(name), GetStructInfo(inputs[i])); @@ -145,11 +160,16 @@ class TupleFuser : public ExprMutator { Expr out_expr; String func_name; + Span expr_span = expr->span; + if (!expr_span.defined()) { + ICHECK(tuple_name.size() > 0) << "Missing tuple for " << expr; + expr_span = SpanUtils::CreateWithAttr(msc_attr::kName, tuple_name); + } if (expr->IsInstance()) { - out_expr = Tuple(func_inputs, expr->span); + out_expr = Tuple(func_inputs, expr_span); func_name = "tuple"; } else if (const auto* g_node = expr.as()) { - out_expr = TupleGetItem(func_inputs[0], g_node->index, expr->span); + out_expr = TupleGetItem(func_inputs[0], g_node->index, expr_span); func_name = "get_item"; } else { LOG_FATAL << "Unexpceted expr " << expr; @@ -163,7 +183,7 @@ class TupleFuser : public ExprMutator { Map func_attrs; func_attrs.Set(attr::kPrimitive, Integer(1)); func_attrs.Set(attr::kComposite, target_ + func_name); - func_attrs.Set(msc_attr::kUnique, SpanUtils::GetAttr(expr->span, msc_attr::kName)); + func_attrs.Set(msc_attr::kUnique, SpanUtils::GetAttr(expr_span, msc_attr::kName)); Function function = Function(/*params=*/params, // /*body=*/body, // diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index 7fa71df20b45..76e3147a5507 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -2338,6 +2338,57 @@ def forward(self, x, y): verify_model(Max(), [([bz, 256], "float32"), ([bz, 256], "float32")], expected) +@pytest.mark.parametrize("dynamic", [True, False]) +def test_cat(dynamic): + """test graph builder for cat""" + + class Cat1(Module): + def forward(self, data, data1, data2): + return torch.cat((data, data1, data2), dim=1) + + class Cat2(Module): + def forward(self, data): + const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + return torch.cat((data, const1, const2), dim=1) + + bz = "bz" if dynamic else 1 + dim = "dim" if dynamic else 3 + input_info = [ + ([bz, dim, 10, 10], "float32"), + ([bz, dim, 10, 10], "float32"), + ([bz, dim, 10, 10], "float32"), + ] + expected1 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""}, + {"name": "inp_2", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""}, + ], + "outputs": [ + { + "name": "concat", + "shape": [bz, "MUL_3" if dynamic else 9, 10, 10], + "dtype": "float32", + "layout": "ABCD", + } + ], + "nodes": {"total": 4, "input": 3, "concat": 1}, + } + expected2 = { + "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "outputs": [ + {"name": "concat", "shape": [1, 9, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "nodes": {"total": 4, "input": 1, "constant": 2, "concat": 1}, + } + if dynamic: + expected1["prims"] = {"total": 4, "shape": 2, "Int": 1, "Mul": 1} + + verify_model(Cat1(), input_info, expected1) + verify_model(Cat2(), [([1, 3, 10, 10], "float32")], expected2) + + @pytest.mark.parametrize("dynamic", [True, False]) def test_attention(dynamic): """test graph builder for attention""" diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index 66aa90a625ea..64d00bb0922e 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -29,7 +29,9 @@ from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen -def _verify_model(torch_model, input_info, opt_config=None): +def verify_model(torch_model, input_info, opt_config=None): + """Compare torch module IR""" + graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): orig_mod = from_fx(graph_model, input_info) @@ -92,8 +94,8 @@ def forward(self, data): return self.conv(data) input_info = [([1, 3, 10], "float32")] - _verify_model(Conv1D1(), input_info) - _verify_model(Conv1D2(), input_info) + verify_model(Conv1D1(), input_info) + verify_model(Conv1D2(), input_info) def test_conv2d(): @@ -116,8 +118,8 @@ def forward(self, data): return self.conv(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Conv2D1(), input_info) - _verify_model(Conv2D2(), input_info) + verify_model(Conv2D1(), input_info) + verify_model(Conv2D2(), input_info) def test_linear(): @@ -144,9 +146,9 @@ def forward(self, x, y): return torch.matmul(x, y) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Dense1(), input_info) - _verify_model(Dense2(), input_info) - _verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) + verify_model(Dense1(), input_info) + verify_model(Dense2(), input_info) + verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) def test_bmm(): @@ -157,7 +159,7 @@ def forward(self, x, y): return torch.bmm(x, y) input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] - _verify_model(BMM(), input_info) + verify_model(BMM(), input_info) def test_baddbmm(): @@ -176,8 +178,8 @@ def forward(self, c, x, y): ((4, 128, 256), "float32"), ((4, 256, 512), "float32"), ] - _verify_model(BAddBMM1(), input_info) - _verify_model(BAddBMM2(), input_info) + verify_model(BAddBMM1(), input_info) + verify_model(BAddBMM2(), input_info) def test_relu(): @@ -196,8 +198,8 @@ def forward(self, data): return torch.nn.functional.relu(data) input_info = [([10, 10], "float32")] - _verify_model(ReLU(), input_info) - _verify_model(ReLU1(), input_info) + verify_model(ReLU(), input_info) + verify_model(ReLU1(), input_info) def test_relu6(): @@ -212,7 +214,7 @@ def forward(self, data): return self.relu6(data) input_info = [([10, 10], "float32")] - _verify_model(ReLU6(), input_info) + verify_model(ReLU6(), input_info) def test_maxpool2d(): @@ -243,9 +245,9 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(MaxPool2d(), input_info) - _verify_model(MaxPool2d2(), input_info) - _verify_model(MaxPool2d3(), input_info) + verify_model(MaxPool2d(), input_info) + verify_model(MaxPool2d2(), input_info) + verify_model(MaxPool2d3(), input_info) def test_avgpool2d(): @@ -268,8 +270,8 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(AvgPool2d(), input_info) - _verify_model(AvgPool2d2(), input_info) + verify_model(AvgPool2d(), input_info) + verify_model(AvgPool2d2(), input_info) def test_adaptive_avgpool2d(): @@ -284,7 +286,7 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(AdaptiveAvgPool2d0(), input_info) + verify_model(AdaptiveAvgPool2d0(), input_info) def test_flatten(): @@ -299,8 +301,8 @@ def forward(self, data): return self.f(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Flatten(), input_info) - _verify_model(torch.nn.Flatten(2, -1), input_info) + verify_model(Flatten(), input_info) + verify_model(torch.nn.Flatten(2, -1), input_info) def test_batchnorm2d(): @@ -315,7 +317,7 @@ def forward(self, data): return self.batchnorm(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(BatchNorm2d(), input_info) + verify_model(BatchNorm2d(), input_info) def test_embedding(): @@ -329,8 +331,8 @@ def __init__(self): def forward(self, data): return self.embedding(data) - _verify_model(Embedding(), [([4], "int64")]) - _verify_model(Embedding(), [([4, 5], "int64")]) + verify_model(Embedding(), [([4], "int64")]) + verify_model(Embedding(), [([4, 5], "int64")]) def test_dropout(): @@ -349,8 +351,8 @@ def forward(self, data): return torch.dropout(data, 0.5, train=True) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Dropout1(), input_info) - _verify_model(Dropout2(), input_info) + verify_model(Dropout1(), input_info) + verify_model(Dropout2(), input_info) def test_layernorm(): @@ -365,7 +367,7 @@ def forward(self, data): return self.layernorm(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(LayerNorm(), input_info) + verify_model(LayerNorm(), input_info) def test_functional_layernorm(): @@ -383,7 +385,7 @@ def forward(self, data): ) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(LayerNorm((10, 10)), input_info) + verify_model(LayerNorm((10, 10)), input_info) def test_cross_entropy(): @@ -415,9 +417,9 @@ def forward(self, logits, targets): return self.loss(logits, targets) input_info = [([3, 2], "float32"), ([3], "int32")] - _verify_model(CrossEntropy1(), input_info) - _verify_model(CrossEntropy2(), input_info) - _verify_model(CrossEntropy3(), input_info) + verify_model(CrossEntropy1(), input_info) + verify_model(CrossEntropy2(), input_info) + verify_model(CrossEntropy3(), input_info) def test_functional_cross_entropy(): @@ -428,7 +430,7 @@ def forward(self, logits, targets): return torch.nn.functional.cross_entropy(logits, targets) input_info = [([3, 10], "float32"), ([3], "int32")] - _verify_model(CrossEntropy(), input_info) + verify_model(CrossEntropy(), input_info) def test_silu(): @@ -447,8 +449,8 @@ def forward(self, data): return torch.nn.functional.silu(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(SiLU(), input_info) - _verify_model(SiLU2(), input_info) + verify_model(SiLU(), input_info) + verify_model(SiLU2(), input_info) def test_groupnorm(): @@ -463,7 +465,7 @@ def forward(self, data): return self.groupnorm(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(GroupNorm(), input_info) + verify_model(GroupNorm(), input_info) def test_softmax(): @@ -478,7 +480,7 @@ def forward(self, data): return self.softmax(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Softmax(), input_info) + verify_model(Softmax(), input_info) def test_binary(): @@ -496,8 +498,8 @@ class Add2(Module): def forward(self, lhs): return lhs + 1.0 - _verify_model(Add1(), input_info1) - _verify_model(Add2(), input_info2) + verify_model(Add1(), input_info1) + verify_model(Add2(), input_info2) # Sub class Sub1(Module): @@ -508,8 +510,8 @@ class Sub2(Module): def forward(self, lhs): return lhs - 1.0 - _verify_model(Sub1(), input_info1) - _verify_model(Sub2(), input_info2) + verify_model(Sub1(), input_info1) + verify_model(Sub2(), input_info2) # Mul class Mul1(Module): @@ -520,8 +522,8 @@ class Mul2(Module): def forward(self, lhs): return lhs * 1.0 - _verify_model(Mul1(), input_info1) - _verify_model(Mul2(), input_info2) + verify_model(Mul1(), input_info1) + verify_model(Mul2(), input_info2) # True div class TrueDiv1(Module): @@ -532,8 +534,8 @@ class TrueDiv2(Module): def forward(self, lhs): return lhs / 1.0 - _verify_model(TrueDiv1(), input_info1) - _verify_model(TrueDiv2(), input_info2) + verify_model(TrueDiv1(), input_info1) + verify_model(TrueDiv2(), input_info2) # Floor div class FloorDiv1(Module): @@ -544,8 +546,8 @@ class FloorDiv2(Module): def forward(self, lhs): return lhs // 1.0 - _verify_model(FloorDiv1(), input_info1) - _verify_model(FloorDiv2(), input_info2) + verify_model(FloorDiv1(), input_info1) + verify_model(FloorDiv2(), input_info2) # Power class Power1(Module): @@ -556,8 +558,8 @@ class Power2(Module): def forward(self, lhs): return lhs**1.0 - _verify_model(Power1(), input_info1) - _verify_model(Power2(), input_info2) + verify_model(Power1(), input_info1) + verify_model(Power2(), input_info2) # LT class LT1(Module): @@ -568,8 +570,8 @@ class LT2(Module): def forward(self, lhs): return lhs < 1.0 - _verify_model(LT1(), input_info1) - _verify_model(LT2(), input_info2) + verify_model(LT1(), input_info1) + verify_model(LT2(), input_info2) def test_size(): @@ -580,7 +582,7 @@ def forward(self, data): return data.size() input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Size(), input_info) + verify_model(Size(), input_info) def test_squeeze(): @@ -595,8 +597,8 @@ def forward(self, data): return data.squeeze() input_info = [([3, 1, 4, 1], "float32")] - _verify_model(Squeeze1(), input_info) - _verify_model(Squeeze2(), input_info) + verify_model(Squeeze1(), input_info) + verify_model(Squeeze2(), input_info) def test_unsqueeze(): @@ -611,8 +613,8 @@ def forward(self, data): return data.unsqueeze(-1) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Unsqueeze1(), input_info) - _verify_model(Unsqueeze2(), input_info) + verify_model(Unsqueeze1(), input_info) + verify_model(Unsqueeze2(), input_info) def test_getattr(): @@ -623,7 +625,7 @@ def forward(self, data): return data.shape input_info = [([1, 3, 10, 10], "float32")] - _verify_model(GetAttr1(), input_info) + verify_model(GetAttr1(), input_info) def test_getitem(): @@ -637,8 +639,8 @@ class Slice2(Module): def forward(self, x): return x[:, None, None, :, None] - _verify_model(Slice1(), [([1, 3, 10, 10], "float32")]) - _verify_model(Slice2(), [([8, 16], "float32")]) + verify_model(Slice1(), [([1, 3, 10, 10], "float32")]) + verify_model(Slice2(), [([8, 16], "float32")]) def test_unary(): @@ -651,42 +653,42 @@ class Sin(Module): def forward(self, data): return torch.sin(data) - _verify_model(Sin(), input_info) + verify_model(Sin(), input_info) # cos class Cos(Module): def forward(self, data): return torch.cos(data) - _verify_model(Cos(), input_info) + verify_model(Cos(), input_info) # exp class Exp(Module): def forward(self, data): return torch.exp(data) - _verify_model(Exp(), input_info) + verify_model(Exp(), input_info) # sqrt class Sqrt(Module): def forward(self, data): return torch.sqrt(data) - _verify_model(Sqrt(), input_info) + verify_model(Sqrt(), input_info) # sigmoid class Sigmoid(Module): def forward(self, data): return torch.sigmoid(data) - _verify_model(Sigmoid(), input_info) + verify_model(Sigmoid(), input_info) # round class Round(Module): def forward(self, data): return torch.round(data) - _verify_model(Round(), input_info) + verify_model(Round(), input_info) def test_gelu(): @@ -697,7 +699,7 @@ def forward(self, data): return torch.nn.functional.gelu(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Gelu(), input_info) + verify_model(Gelu(), input_info) def test_tanh(): @@ -708,7 +710,7 @@ def forward(self, data): return torch.tanh(data) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Tanh(), input_info) + verify_model(Tanh(), input_info) def test_clamp(): @@ -719,7 +721,7 @@ def forward(self, data): return torch.clamp(data, min=0.1, max=0.5) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Clamp(), input_info) + verify_model(Clamp(), input_info) def test_interpolate(): @@ -730,7 +732,7 @@ def forward(self, data): return torch.nn.functional.interpolate(data, (5, 5)) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Interpolate(), input_info) + verify_model(Interpolate(), input_info) def test_addmm(): @@ -745,7 +747,7 @@ def forward(self, x_1, x_2, x_3): ([10, 10], "float32"), ([10, 10], "float32"), ] - _verify_model(Addmm(), input_info) + verify_model(Addmm(), input_info) def test_split(): @@ -760,8 +762,8 @@ def forward(self, data): return torch.split(data, [1, 2], dim=1) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Split1(), input_info) - _verify_model(Split2(), input_info) + verify_model(Split1(), input_info) + verify_model(Split2(), input_info) def test_unbind(): @@ -776,8 +778,8 @@ def forward(self, data): return torch.unbind(data, dim=1) input_info = [([3, 3, 10, 10], "float32")] - _verify_model(Unbind1(), input_info) - _verify_model(Unbind2(), input_info) + verify_model(Unbind1(), input_info) + verify_model(Unbind2(), input_info) def test_cumsum(): @@ -788,7 +790,7 @@ def forward(self, data): return torch.cumsum(data, dim=1, dtype=torch.int32) input_info = [([1, 2, 3, 4], "float32")] - _verify_model(Cumsum(), input_info) + verify_model(Cumsum(), input_info) def test_chunk(): @@ -799,7 +801,7 @@ def forward(self, data): return torch.chunk(data, 3, dim=1) input_info = [([1, 3, 10, 10], "float32")] - _verify_model(Chunk(), input_info) + verify_model(Chunk(), input_info) def test_inplace_fill(): @@ -810,7 +812,7 @@ def forward(self, data): data.fill_(1.5) return data - _verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0}) + verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0}) def test_arange(): @@ -820,7 +822,7 @@ class Arange(Module): def forward(self): return torch.arange(0, 20, dtype=torch.int32) - _verify_model(Arange(), [([10, 10], "float32")]) + verify_model(Arange(), [([10, 10], "float32")]) def test_empty(): @@ -830,7 +832,7 @@ class Empty(Module): def forward(self): return torch.empty((10, 10), dtype=torch.float32) - _verify_model(Empty(), [([10, 10], "float32")]) + verify_model(Empty(), [([10, 10], "float32")]) def test_tensor(): @@ -844,8 +846,8 @@ class Empty2(Module): def forward(self): return torch.tensor(3) - _verify_model(Empty1(), [([10, 10], "float32")]) - _verify_model(Empty2(), [([10, 10], "float32")]) + verify_model(Empty1(), [([10, 10], "float32")]) + verify_model(Empty2(), [([10, 10], "float32")]) def test_tril(): @@ -861,8 +863,8 @@ def forward(self, data): return data input_info = [([10, 10], "float32")] - _verify_model(Tril(), input_info) - _verify_model(InplaceTril(), input_info) + verify_model(Tril(), input_info) + verify_model(InplaceTril(), input_info) def test_triu(): @@ -878,8 +880,8 @@ def forward(self, data): return data input_info = [([10, 10], "float32")] - _verify_model(Triu(), input_info) - _verify_model(InplaceTriu(), input_info) + verify_model(Triu(), input_info) + verify_model(InplaceTriu(), input_info) def test_new_ones(): @@ -890,7 +892,7 @@ def forward(self, x): return x.new_ones(1, 2, 3) input_info = [([1, 2, 3], "float32")] - _verify_model(NewOnes(), input_info, opt_config={"opt_level": 0}) + verify_model(NewOnes(), input_info, opt_config={"opt_level": 0}) def test_expand(): @@ -905,8 +907,8 @@ def forward(self, x): return x.expand(4, -1, -1, 4) input_info = [([1, 2, 3, 4], "float32")] - _verify_model(Expand1(), input_info) - _verify_model(Expand2(), input_info) + verify_model(Expand1(), input_info) + verify_model(Expand2(), input_info) def test_reduce(): @@ -918,7 +920,7 @@ def forward(self, x): return torch.sum(x, (2, 1)) input_info = [([1, 2, 3, 4], "float32")] - _verify_model(Sum(), input_info) + verify_model(Sum(), input_info) def test_datatype(): @@ -931,14 +933,14 @@ class ToFloat(Module): def forward(self, x): return x.float() - _verify_model(ToFloat(), input_info) + verify_model(ToFloat(), input_info) # half class ToHalf(Module): def forward(self, x): return x.half() - _verify_model(ToHalf(), input_info) + verify_model(ToHalf(), input_info) # type class Type(Module): @@ -955,9 +957,9 @@ class AsType(Module): def forward(self, x): return x.astype(torch.float32) - _verify_model(Type(), input_info) - _verify_model(TypeFromAttr(), input_info) - _verify_model(AsType(), input_info) + verify_model(Type(), input_info) + verify_model(TypeFromAttr(), input_info) + verify_model(AsType(), input_info) def test_permute(): @@ -968,7 +970,7 @@ def forward(self, x): return x.permute(0, 3, 2, 1) input_info = [([1, 2, 3, 4], "float32")] - _verify_model(Permute(), input_info) + verify_model(Permute(), input_info) def test_reshape(): @@ -979,7 +981,7 @@ def forward(self, x): return x.reshape(2, 12) input_info = [([1, 2, 3, 4], "float32")] - _verify_model(Reshape(), input_info) + verify_model(Reshape(), input_info) def test_transpose(): @@ -990,7 +992,7 @@ def forward(self, x): return x.transpose(1, 3) input_info = [([1, 2, 3, 4], "float32")] - _verify_model(Transpose(), input_info) + verify_model(Transpose(), input_info) def test_view(): @@ -1001,7 +1003,7 @@ def forward(self, x): return x.view(2, 12) input_info = [([1, 2, 3, 4], "float32")] - _verify_model(View(), input_info) + verify_model(View(), input_info) def test_keep_params(): @@ -1015,7 +1017,7 @@ def __init__(self): def forward(self, data): return self.conv(data) - _verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")]) + verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")]) def test_unwrap_unit_return_tuple(): @@ -1025,7 +1027,7 @@ class Identity(Module): def forward(self, x): return (x,) - _verify_model(Identity(), [([256, 256], "float32")]) + verify_model(Identity(), [([256, 256], "float32")]) def test_no_bind_return_tuple(): @@ -1036,7 +1038,7 @@ def forward(self, x, y): return (x, y) input_info = [([256, 256], "float32"), ([256, 256], "float32")] - _verify_model(Identity(), input_info) + verify_model(Identity(), input_info) def test_argmax(): @@ -1050,8 +1052,8 @@ class Argmax2(Module): def forward(self, data): return torch.argmax(data, dim=-1, keepdim=True) - _verify_model(Argmax1(), [([256, 256], "float32")]) - _verify_model(Argmax2(), [([256, 256], "float32")]) + verify_model(Argmax1(), [([256, 256], "float32")]) + verify_model(Argmax2(), [([256, 256], "float32")]) def test_argmin(): @@ -1065,8 +1067,8 @@ class Argmin2(Module): def forward(self, data): return torch.argmin(data, keepdim=True) - _verify_model(Argmin1(), [([256, 256], "float32")]) - _verify_model(Argmin2(), [([256, 256], "float32")]) + verify_model(Argmin1(), [([256, 256], "float32")]) + verify_model(Argmin2(), [([256, 256], "float32")]) def test_to(): @@ -1080,8 +1082,8 @@ class To2(Module): def forward(self, data): return data.to("cpu") - _verify_model(To1(), [([256, 256], "float32")]) - _verify_model(To2(), [([256, 256], "float32")]) + verify_model(To1(), [([256, 256], "float32")]) + verify_model(To2(), [([256, 256], "float32")]) def test_mean(): @@ -1095,8 +1097,8 @@ class MeanKeepDim(Module): def forward(self, data): return data.mean(-1, keepdim=True) - _verify_model(Mean(), [([256, 256], "float32")]) - _verify_model(MeanKeepDim(), [([256, 256], "float32")]) + verify_model(Mean(), [([256, 256], "float32")]) + verify_model(MeanKeepDim(), [([256, 256], "float32")]) def test_rsqrt(): @@ -1106,7 +1108,7 @@ class Rsqrt(Module): def forward(self, data): return torch.rsqrt(data) - _verify_model(Rsqrt(), [([256, 256], "float32")]) + verify_model(Rsqrt(), [([256, 256], "float32")]) def test_neg(): @@ -1116,7 +1118,7 @@ class Neg(Module): def forward(self, data): return -data - _verify_model(Neg(), [([256, 256], "float32")]) + verify_model(Neg(), [([256, 256], "float32")]) def test_max(): @@ -1126,7 +1128,29 @@ class Max(Module): def forward(self, x, y): return torch.max(x, y) - _verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) + verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) + + +def test_cat(): + """test relax translator for cat""" + + class Cat1(Module): + def forward(self, data, data1, data2): + return torch.cat((data, data1, data2), dim=1) + + class Cat2(Module): + def forward(self, data): + const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + return torch.cat((data, const1, const2), dim=1) + + input_info = [ + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ] + verify_model(Cat1(), input_info) + verify_model(Cat2(), [([1, 3, 10, 10], "float32")]) def test_attention(): @@ -1148,14 +1172,14 @@ def forward(self, q_data, k_data, v_data): ([32, 8, 128, 64], "float32"), ([32, 8, 128, 64], "float32"), ] - _verify_model(Attention1(), input_info) - _verify_model(Attention2(), input_info) + verify_model(Attention1(), input_info) + verify_model(Attention2(), input_info) class Attention3(Module): def forward(self, q_data, k_data, v_data, mask): return F.scaled_dot_product_attention(q_data, k_data, v_data, mask) - _verify_model( + verify_model( Attention3(), [ ([32, 8, 128, 64], "float32"), diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py index 3790da3f3d8e..ebba339a4a3e 100644 --- a/tests/python/contrib/test_msc/test_translate_relay.py +++ b/tests/python/contrib/test_msc/test_translate_relay.py @@ -1086,6 +1086,28 @@ def forward(self, x, y): verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) +def test_cat(): + """test relay to relax for cat""" + + class Cat1(Module): + def forward(self, data, data1, data2): + return torch.cat((data, data1, data2), dim=1) + + class Cat2(Module): + def forward(self, data): + const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + return torch.cat((data, const1, const2), dim=1) + + input_info = [ + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ] + verify_model(Cat1(), input_info, build_target="llvm") + verify_model(Cat2(), [([1, 3, 10, 10], "float32")], build_target="llvm") + + def test_name_string_with_colon(): """test name string with colons, e.g., TFLite default input name 'serving_default_input:0' diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py index 7c8c2830995c..6d87ca8753dc 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorrt.py +++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py @@ -893,5 +893,28 @@ def forward(self, data): verify_model(Gelu2(), input_info) +@requires_tensorrt +def test_cat(): + """test tensorrt translator for cat""" + + class Cat1(Module): + def forward(self, data, data1, data2): + return torch.cat((data, data1, data2), dim=1) + + class Cat2(Module): + def forward(self, data): + const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + return torch.cat((data, const1, const2), dim=1) + + input_info = [ + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ] + verify_model(Cat1(), input_info) + verify_model(Cat2(), [([1, 3, 10, 10], "float32")]) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py index f3e01493d96a..55bae682ef20 100644 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ b/tests/python/contrib/test_msc/test_translate_torch.py @@ -1105,6 +1105,29 @@ def forward(self, x, y): verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], via_relax) +def test_cat(): + """test torch translator for cat""" + + class Cat1(Module): + def forward(self, data, data1, data2): + return torch.cat((data, data1, data2), dim=1) + + class Cat2(Module): + def forward(self, data): + const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) + return torch.cat((data, const1, const2), dim=1) + + input_info = [ + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ([1, 3, 10, 10], "float32"), + ] + for via_relax in [True, False]: + verify_model(Cat1(), input_info, via_relax) + verify_model(Cat2(), [([1, 3, 10, 10], "float32")], via_relax) + + def test_attention(): """test torch translator for attention""" From 72d542e71c628bc3d6bd983c2cd753a663b521a6 Mon Sep 17 00:00:00 2001 From: XinhuaHamiMelon Date: Sun, 22 Sep 2024 14:55:45 +0800 Subject: [PATCH 159/202] [Bugfix][ONNX] Skip constant If node generated by PyTorch (#17383) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Bugfix][VTA] Fix FSIM compile error on macOS. VTA FSIM could not be built on macOS, for it leverages malloc.h and memalign, yet both have been deprecated and are not provided by macOS. This issue was captured in #13173. This commit stops including malloc.h in VTA Runtime as stdlib.h has provided functions we need. This commit uses posix_memalign instead of memalign. It is a portable standard function. * Fix format. * [Bugfix][ONNX] Skip constant If node generated by PyTorch This commit adds a check for If nodes for ONNX frontend of Relay to skip the broadcast if the predicate is constant. Sometimes PyTorch to ONNX inserts silly if nodes that produce dynamic ranks, and ONNX frontend of TVM would broadcast the lower dimensions between branches, which is irrational for some cases, e.g. 5×5×3×4 to 5×5×3×4×1. The predicate of silly if might be constant and reasonable to skip to avoid the broadcast problem. This issue was captured in #16898. * Fix format. --- python/tvm/relay/frontend/onnx.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ee7a5d6b329a..8da8a5b11262 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4565,6 +4565,23 @@ def _impl_v1(cls, inputs, attr, params): "Attempting to unify ranks but this may produce incorrect results." ) warnings.warn(warning_msg) + # Skip constant If node to avoid irrational broadcast + if isinstance(inputs[0], tvm.relay.expr.Constant): + predicate = inputs[0].data.asnumpy()[0] + node_name = attr["tvm_custom"]["name"] + warn_msg_begin = f"Predicate of If node {node_name} is always " + if predicate == np.bool_(True): + warnings.warn( + warn_msg_begin + + "true so only then branch would be executed. Removing else branch. " + ) + else_expr = then_expr + elif predicate == np.bool_(False): + warnings.warn( + warn_msg_begin + + "false so only else branch would be executed. Removing then branch. " + ) + then_expr = else_expr if len(then_shape) < len(else_shape): then_expr = _op.broadcast_to_like(then_expr, else_expr) else: @@ -6529,6 +6546,7 @@ def _impl_v11(cls, inputs, attr, params): # compatible operators that do NOT require any conversion. _identity_list = [] + # _convert_map defines maps of name to converter functor(callable) # for 1 to 1 mapping, use Renamer if nothing but name is different # use AttrCvt if attributes need to be converted From 36ff1f146c6ad8debcc6675fb2dfc5537fc233dc Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 22 Sep 2024 08:58:24 -0400 Subject: [PATCH 160/202] [3rdparty] Bump FlashInfer for tmp workspace reduction (#17400) This PR bumps FlashInfer to reduce the size of required temporary workspace. --- 3rdparty/flashinfer | 2 +- src/runtime/relax_vm/paged_kv_cache.cc | 29 ++++++++++++------- ...tin_paged_attention_kv_cache_flashinfer.py | 2 +- ...me_builtin_paged_attention_kv_cache_tir.py | 2 +- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer index 0dd801d2027a..1e379898a589 160000 --- a/3rdparty/flashinfer +++ b/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 0dd801d2027af89f3603cbbf68a76e9503bb2f57 +Subproject commit 1e379898a589cdd4ff18a4621fcbe18d63501545 diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 8809a1b0729e..78a7ed1dd1f8 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -57,8 +57,10 @@ namespace relax_vm { constexpr const int kPagedKVCacheMaxBlockDepth = 2; /*! \brief The maximum tree size of a single sequence in tree attention. */ constexpr const int kTreeAttnMaxTreeSize = 256; -/*! \brief The 8MB workspace size for attention auxiliary data. */ -constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024; +/*! \brief The 1MB workspace size for integer attention auxiliary data. */ +constexpr const int kIntAttnWorkspaceByte = 1 * 1024 * 1024; +/*! \brief The 128MB workspace size for floating-point attention auxiliary data. */ +constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 * 1024; /*! \brief The id of the temporary logical page, which is useful for sliding window. */ constexpr const int kPagedKVCacheTempPageId = -1; @@ -915,7 +917,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray temp_attn_output_device_; NDArray temp_attn_scores_device_; NDArray merged_attn_scores_device_; - std::vector temp_attn_workspace_; + std::vector temp_int_attn_workspace_; + NDArray temp_float_attn_workspace_; //------------------------------------------- // Below are the auxiliary data structure on CPU. @@ -1089,8 +1092,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { if (NeedKernelBeginForward()) { - temp_attn_workspace_.push_back( - NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device)); + temp_int_attn_workspace_.push_back( + NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device)); } qo_indptr_on_depths_view_.push_back(NDArray()); page_indptr_on_depths_view_.push_back(NDArray()); @@ -1103,8 +1106,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // Additional workspace for the "prefill with ragged kv" kernel. if (NeedKernelBeginForward()) { - temp_attn_workspace_.push_back( - NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device)); + temp_int_attn_workspace_.push_back( + NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device)); + temp_float_attn_workspace_ = + NDArray::Empty({kFloatAttnWorkspaceByte / 4}, DataType::Float(32), device); } temp_attn_q_device_ = @@ -2324,7 +2329,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (!append_before_attn_) { if (is_chain_on_depths_[0]) { f_attention_prefill_ragged_begin_forward_.value()( - temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), + temp_float_attn_workspace_, temp_int_attn_workspace_[0], + cur_append_lengths_indptr_host_.as_ndarray(), cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_); } @@ -2336,14 +2342,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; if (use_decode_kernel_[d]) { f_attention_decode_begin_forward_.value()( - d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(), + d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1], + page_indptr_on_depths_host_[d].as_ndarray(), last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); } else { f_attention_prefill_begin_forward_.value()( - /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(), - page_indptr_on_depths_host_[d].as_ndarray(), + /*depth=*/d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1], + qo_indptr_on_depths_host_[d].as_ndarray(), page_indptr_on_depths_host_[d].as_ndarray(), static_cast(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_, num_kv_heads_, head_dim_, page_size_, copy_stream_); } diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index 2252cb8d9c09..4c25383178ac 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -324,7 +324,7 @@ def set_global_func(): ) fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place") - target = tvm.target.Target("nvidia/geforce-rtx-3090-ti") + target = tvm.target.Target.from_device(device) builts = [] for tir_func in [ kv_cache_transpose_append, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 5ab96caa9bc0..82f85f4b17fa 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -111,7 +111,7 @@ def set_global_func(head_dim, dtype): fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty") fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv") - target = tvm.target.Target("cuda") + target = tvm.target.Target.from_device(device) builts = [] for tir_func in [ _kv_cache_transpose_append(num_kv_heads, head_dim, dtype), From ce461859c5a8dcb0a38b0af83ff206f2f2751e47 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 22 Sep 2024 11:02:58 -0400 Subject: [PATCH 161/202] [KVCache] Attention func accepting over-padded qkv and output NDArray (#17401) This PR enhances the `AttentionWithFusedQKV` function of `PagedKVCache` so that it can now accept input `qkv_data` and `o_data` that have padding along the sequence dimension. We introduce this enhancement to allow more flexibility for the caller of PagedKVCache to decide whether to pad the input qkv/o NDArrays or not. --- src/runtime/relax_vm/paged_kv_cache.cc | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 78a7ed1dd1f8..b6636ae1a7d4 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -1755,7 +1755,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { total_seq_length += cur_append_lengths_[seq_id]; } - CHECK_EQ(total_seq_length, qkv_data->shape[0]); + CHECK_LE(total_seq_length, qkv_data->shape[0]); // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. @@ -1767,12 +1767,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qkv_data->dtype); NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_}, qkv_data->dtype); + + NDArray qkv_data_view = qkv_data; + NDArray o_data_view = o_data; + if (total_seq_length != qkv_data->shape[0]) { + qkv_data_view = qkv_data.CreateView( + {total_seq_length, qkv_data->shape[1], qkv_data->shape[2]}, qkv_data->dtype); + o_data_view = + o_data.CreateView({total_seq_length, num_qo_heads_, head_dim_}, qkv_data->dtype); + } // Part 2. Split fused qkv and apply rotary embedding to q/k data. if (!rope_ext_factors_.defined()) { - f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, + f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data, static_cast(rope_mode_ == RoPEMode::kNormal)); } else { - f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, + f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data, rope_ext_factors_.value()); } @@ -1781,7 +1790,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_); } // Part 4: perform attention - AttentionInternal(layer_id, q_data, k_data, v_data, o_data, attn_score_scaling_factor); + AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, attn_score_scaling_factor); // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not set. if (!append_before_attn_) { f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_); From 66b21d3c25d93631a91d5b6758eb379c2055c00c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 23 Sep 2024 08:21:20 -0400 Subject: [PATCH 162/202] [Fix][LLVM] Fix getHostCPUFeatures LLVM version cutoff (#17403) This PR fixes the LLVM version cutoff for `llvm::sys::getHostCPUFeatures`. Previously the cutoff version is set to 20.0, assuming that the signature change happens since LLVM 20.0. While actually the signature change happens at 19.0. Reference: * LLVM 18.1.8 https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/include/llvm/TargetParser/Host.h#L56 * LLVM 19.1.0 https://github.com/llvm/llvm-project/blob/llvmorg-19.1.0-rc1/llvm/include/llvm/TargetParser/Host.h#L55 --- src/target/llvm/codegen_llvm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 4c5bea8c9b4b..e21436e556ee 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -2315,7 +2315,7 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUName").set_body_typed([]() -> st TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") .set_body_typed([]() -> Map { -#if TVM_LLVM_VERSION >= 200 +#if TVM_LLVM_VERSION >= 190 Map ret; auto features = llvm::sys::getHostCPUFeatures(); for (auto it = features.begin(); it != features.end(); ++it) { From 9e2a75d64e937390eab2985743fef47cdeaf3c81 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 23 Sep 2024 21:22:04 +0900 Subject: [PATCH 163/202] [CI] Update image tag to 20240917-153130-9f281758 (#17397) * update image tag to 20240917-153130-9f281758 * increase atol * define custom equal operator to avoid comparison error * try to remove android stuff * skip test_imagenet --- ci/jenkins/docker-images.ini | 20 +++++----- .../python/frontend/pytorch/test_fx_quant.py | 3 ++ tests/python/relax/test_frontend_onnx.py | 5 ++- .../test_tir_transform_simplify.py | 38 +++++++++++++++---- tests/scripts/task_build_hexagon_api.sh | 5 +-- 5 files changed, 47 insertions(+), 24 deletions(-) diff --git a/ci/jenkins/docker-images.ini b/ci/jenkins/docker-images.ini index 6e55160521b3..175917f887b7 100644 --- a/ci/jenkins/docker-images.ini +++ b/ci/jenkins/docker-images.ini @@ -17,13 +17,13 @@ # This data file is read during when Jenkins runs job to determine docker images. [jenkins] -ci_arm: tlcpack/ci-arm:20240428-060115-0b09ed018 -ci_cortexm: tlcpack/ci-cortexm:20240428-060115-0b09ed018 -ci_cpu: tlcpack/ci_cpu:20240428-060115-0b09ed018 -ci_gpu: tlcpack/ci-gpu:20240428-060115-0b09ed018 -ci_hexagon: tlcpack/ci-hexagon:20240428-060115-0b09ed018 -ci_i386: tlcpack/ci-i386:20240428-060115-0b09ed018 -ci_lint: tlcpack/ci-lint:20240428-060115-0b09ed018 -ci_minimal: tlcpack/ci-minimal:20240428-060115-0b09ed018 -ci_riscv: tlcpack/ci-riscv:20240428-060115-0b09ed018 -ci_wasm: tlcpack/ci-wasm:20240428-060115-0b09ed018 +ci_arm: tlcpack/ci-arm:20240917-153130-9f281758 +ci_cortexm: tlcpack/ci-cortexm:20240917-153130-9f281758 +ci_cpu: tlcpack/ci_cpu:20240917-153130-9f281758 +ci_gpu: tlcpack/ci-gpu:20240917-153130-9f281758 +ci_hexagon: tlcpack/ci-hexagon:20240917-153130-9f281758 +ci_i386: tlcpack/ci-i386:20240917-153130-9f281758 +ci_lint: tlcpack/ci-lint:20240917-153130-9f281758 +ci_minimal: tlcpack/ci-minimal:20240917-153130-9f281758 +ci_riscv: tlcpack/ci-riscv:20240917-153130-9f281758 +ci_wasm: tlcpack/ci-wasm:20240917-153130-9f281758 diff --git a/tests/python/frontend/pytorch/test_fx_quant.py b/tests/python/frontend/pytorch/test_fx_quant.py index 7f3083a7dcd0..8ed6e1a74797 100644 --- a/tests/python/frontend/pytorch/test_fx_quant.py +++ b/tests/python/frontend/pytorch/test_fx_quant.py @@ -87,6 +87,9 @@ def forward(self, inp): quantize_and_build(model, 300) +@pytest.mark.skip( + reason="Model binary isn't uploaded to S3. See https://github.com/apache/tvm/pull/17397" +) def test_imagenet(): for model_func in [resnet50, efficientnet_b4]: quantize_and_build(model_func(pretrained=True).eval(), 224) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 8f4e9881f497..0e7cfbd7c093 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -76,6 +76,7 @@ def check_correctness( inputs: Optional[Dict[str, np.ndarray]] = None, ir_version: int = 8, opset: int = 14, + rtol: float = 1e-7, atol: float = 1e-5, ) -> None: """Run an onnx model in both onnxruntime and TVM through our importer @@ -154,7 +155,7 @@ def check_correctness( # TODO Allow configurable tolerance. # Sometimes None is used to indicate an unused output. if ort_out is not None: - tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=atol) + tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol) @pytest.mark.parametrize( @@ -1010,7 +1011,7 @@ def verify_reduce_func(func, data, axis, keepdims): inputs_dict = {"x": data} # Reduction ops accumulate arithmetic errors, so we use a higher tolerance. - check_correctness(model, inputs_dict, opset=11, atol=1e-4) + check_correctness(model, inputs_dict, opset=11, rtol=1e-4, atol=1e-4) for keepdims in [True, False]: verify_reduce_func( diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py index f7887bc61137..0b2d5f16d833 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tir-transform/test_tir_transform_simplify.py @@ -1021,18 +1021,40 @@ class TestMostRestrictiveConditional(BaseBeforeAfter): then `a >= b` cannot be proven, but can be reduced to `a == b`. """ + class TupleWrapper(tuple): + """ + A custom wrapper for `tuple` to handle element-wise equality comparison + to avoid comparison errors when dealing with objects like `ExprOp`. + See also: https://github.com/apache/tvm/pull/17397 + """ + + def __new__(self, *args): + return super().__new__(self, args) + + def __eq__(self, other): + from tvm.tir.expr import ExprOp + + for a, b in zip(self, other): + if isinstance(a, ExprOp) and isinstance(a, ExprOp): + if not tvm.ir.structural_equal(a, b): + return False + else: + if not a.__eq__(b): + return False + return True + i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"] tir_int = tvm.tir.IntImm("int32", 0) test_case = tvm.testing.parameter( - (i <= tir_int, tir_int <= i, i == tir_int), - (i <= tir_int, i != tir_int, i < tir_int), - (i != tir_int, i <= tir_int, i < tir_int), - (i != tir_int, tir_int <= i, tir_int < i), - (i <= j, j <= i, j == i), - (i <= j, i != j, i < j), - (i != j, i <= j, i < j), - (i != j, j <= i, j < i), + TupleWrapper(i <= tir_int, tir_int <= i, i == tir_int), + TupleWrapper(i <= tir_int, i != tir_int, i < tir_int), + TupleWrapper(i != tir_int, i <= tir_int, i < tir_int), + TupleWrapper(i != tir_int, tir_int <= i, tir_int < i), + TupleWrapper(i <= j, j <= i, j == i), + TupleWrapper(i <= j, i != j, i < j), + TupleWrapper(i != j, i <= j, i < j), + TupleWrapper(i != j, j <= i, j < i), ) @tvm.testing.fixture diff --git a/tests/scripts/task_build_hexagon_api.sh b/tests/scripts/task_build_hexagon_api.sh index 5f811e4e2749..cff6d7a6ba59 100755 --- a/tests/scripts/task_build_hexagon_api.sh +++ b/tests/scripts/task_build_hexagon_api.sh @@ -41,10 +41,7 @@ fi mkdir -p build cd build -cmake -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-28 \ - -DUSE_ANDROID_TOOLCHAIN="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \ - -DUSE_HEXAGON_ARCH=v68 \ +cmake -DUSE_HEXAGON_ARCH=v68 \ -DUSE_HEXAGON_SDK="${HEXAGON_SDK_ROOT}" \ -DUSE_HEXAGON_TOOLCHAIN="${HEXAGON_TOOLCHAIN}" \ -DUSE_OUTPUT_BINARY_DIR="${output_directory}" \ From 44808b41c803a3f08a4f43a6455ae0b0df1ac3ba Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Mon, 23 Sep 2024 05:23:40 -0700 Subject: [PATCH 164/202] [WASM] Implement concat embeddings (#17404) * [WASM] Implement concat embeddings * Make concatEmbeddings optional for backward compatibility --- src/target/source/codegen_webgpu.cc | 1 + web/emcc/wasm_runtime.cc | 46 +++++++++++++++++++++++++++++ web/src/runtime.ts | 38 +++++++++++++++++++++++- 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 83079a9f0756..1d1df91dc4a4 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -125,6 +125,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re name_supply_->ReserveName("var"); name_supply_->ReserveName("let"); name_supply_->ReserveName("const"); + name_supply_->ReserveName("std"); // skip the first underscore, so SSA variable starts from name_supply_->FreshName("v_"); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 2f7135595843..9744750b80db 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -173,5 +173,51 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat").set_body([](TVMArgs args, TVMRe } *ret = Array(data); }); + +NDArray ConcatEmbeddings(const std::vector& embeddings) { + // Get output shape + int64_t hidden_size = embeddings[0]->shape[1]; + DLDataType dtype = embeddings[0]->dtype; + DLDevice device = embeddings[0]->device; + int seqLen = 0; + for (int i = 0; i < embeddings.size(); ++i) { + ICHECK_EQ(embeddings[i]->ndim, 2); + ICHECK_EQ(embeddings[i]->shape[1], hidden_size); + seqLen += embeddings[i]->shape[0]; + } + + // Create output + std::vector shape; + shape.push_back(seqLen); + shape.push_back(hidden_size); + NDArray result = NDArray::Empty(shape, dtype, device); + + // Copy + int offset = 0; + for (int i = 0; i < embeddings.size(); i++) { + const DLTensor& copy_src = *(embeddings[i].operator->()); + const DLTensor* p_copy_dst = result.operator->(); + DLTensor copy_dst = *p_copy_dst; + copy_dst.shape = embeddings[i]->shape; + copy_dst.byte_offset = + offset * hidden_size * ((embeddings[i]->dtype.bits * embeddings[i]->dtype.lanes + 7) / 8); + NDArray::CopyFromTo(©_src, ©_dst); + offset += embeddings[i]->shape[0]; + } + + return result; +} + +// Concatenate n NDArrays +TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector embeddings; + for (int i = 0; i < args.size(); ++i) { + ICHECK_EQ(args[i].type_code(), kTVMNDArrayHandle); + embeddings.push_back(args[i]); + } + NDArray result = ConcatEmbeddings(std::move(embeddings)); + *ret = result; +}); + } // namespace runtime } // namespace tvm diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 600a9b857f03..8546cab773ff 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -174,6 +174,7 @@ class RuntimeContext implements Disposable { applyRepetitionPenalty: PackedFunc; applyPresenceAndFrequencyPenalty: PackedFunc; applySoftmaxWithTemperature: PackedFunc; + concatEmbeddings: PackedFunc | undefined; private autoDisposeScope: Array> = []; @@ -199,6 +200,11 @@ class RuntimeContext implements Disposable { this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty"); this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty"); this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature"); + try { + this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings"); + } catch { + // TODO: remove soon. Older artifacts do not have this, try-catch for backward compatibility. + } } dispose(): void { @@ -223,6 +229,7 @@ class RuntimeContext implements Disposable { this.applyRepetitionPenalty.dispose(); this.applyPresenceAndFrequencyPenalty.dispose(); this.applySoftmaxWithTemperature.dispose(); + this.concatEmbeddings?.dispose(); } beginScope(): void { @@ -575,7 +582,10 @@ export class NDArray implements Disposable { * @param data The source data array. * @returns this */ - copyFrom(data: NDArray | Array | Float32Array): this { + copyFrom( + data: NDArray | Array | Float32Array | Float64Array | + Int32Array | Int8Array | Uint8Array | Uint8ClampedArray + ): this { if (data instanceof NDArray) { this.lib.checkCall( (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( @@ -608,6 +618,8 @@ export class NDArray implements Disposable { buffer = Int8Array.from(data).buffer; } else if (this.dtype === "uint8") { buffer = Uint8Array.from(data).buffer; + } else if (this.dtype === "uint32") { + buffer = Uint32Array.from(data).buffer; } else { throw new Error("Unsupported data type " + this.dtype); } @@ -1906,6 +1918,30 @@ export class Instance implements Disposable { return this.ctx.arrayConcat(...listOfArrays) as TVMArray; } + /** + * Join a sequence of NDArrays that represent embeddings. + * @param inputs A list of embeddings in NDArrays, each array i has shape (m_i, hidden_size). + * @returns An NDArray of shape (\sum_{i} {m}, hidden_size) + */ + concatEmbeddings(embeddings: Array): NDArray { + // 1. Check shape validity + const hidden_size = embeddings[0].shape[1]; + embeddings.forEach((input) => { + if (input.shape.length !== 2 || input.shape[1] !== hidden_size) { + throw new Error("Expect embeddings to concatenate have shape (m_i, hidden_size)."); + } + }) + + // 2. Call global func + if (this.ctx.concatEmbeddings === undefined) { + throw new Error( + "Global function tvmjs.runtime.ConcatEmbeddings was " + + "not found, but called concatEmbeddings." + ); + } + return this.ctx.concatEmbeddings(...embeddings) as NDArray; + } + /** * Create a {@link TVMString} that can be consumed by runtime. * From 48d3ada2750959fb06cbb555a3491dbf41a3c155 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 23 Sep 2024 06:17:55 -0700 Subject: [PATCH 165/202] [TIR, TVMScript] Add TIR - Triton integration (#17395) * [TIR, TVMScript] Add TIR - Triton integration Added a macro `T.call_triton` in TIR script parser, which expands to AOT compilation of the kernel and the host TIR code to launch the kernel. --- python/tvm/relax/vm_build.py | 14 +- python/tvm/script/ir_builder/ir/__init__.py | 2 + python/tvm/script/ir_builder/ir/ir.py | 58 ++++++- .../script/ir_builder/tir/external_kernel.py | 141 ++++++++++++++++++ python/tvm/script/ir_builder/tir/ir.py | 3 +- python/tvm/script/ir_builder/tir/triton.py | 115 ++++++++++++++ src/script/ir_builder/ir/ir.cc | 32 +++- .../contrib/test_tir_triton_integration.py | 119 +++++++++++++++ 8 files changed, 477 insertions(+), 7 deletions(-) create mode 100644 python/tvm/script/ir_builder/tir/external_kernel.py create mode 100644 python/tvm/script/ir_builder/tir/triton.py create mode 100644 tests/python/contrib/test_tir_triton_integration.py diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index 243488e5d83f..9fd7a7428588 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -243,13 +243,25 @@ def _vmlink( if ext_libs is None: ext_libs = [] lib = None + relax_ext_libs = [] + tir_ext_libs = [] if tir_mod is not None and len(tir_mod.get_global_vars()) > 0: lib = tvm.build( tir_mod, target=target, runtime=_autodetect_system_lib_req(target, system_lib), ) - return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore + for ext_mod in ext_libs: + if ext_mod.type_key == "cuda": + tir_ext_libs.append(ext_mod) + else: + relax_ext_libs.append(ext_mod) + if lib is not None: + for mod in tir_ext_libs: + lib.import_module(mod) + elif len(tir_ext_libs) > 0: + print("Warning: No TIR module is found, but external modules for TIR are provided.") + return Executable(_ffi_api.VMLink(builder, target, lib, relax_ext_libs, params)) # type: ignore def build( diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index fdf44b2b7918..f604026a1311 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -21,6 +21,8 @@ def_function, ir_module, module_attrs, + module_get_attr, + module_set_attr, module_global_infos, lookup_vdevice, vdevice, diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index d35d73678b47..05ee26e832fb 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,7 +16,7 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" -from typing import Dict, List +from typing import Dict, List, Optional from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, VDevice, DummyGlobalInfo from tvm.runtime import Object as tvm_Object @@ -77,14 +77,66 @@ def def_function(func_name: str, func: BaseFunc) -> None: return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member -def module_attrs(attrs: Dict[str, tvm_Object]) -> None: +def module_attrs(attrs: Dict[str, tvm_Object], allow_overwrite=False) -> None: """Specify the attrs of the ir_module frame. Parameters ---------- attrs: Dict[str, Object] The module attrs. + allow_overwrite: bool + Whether allow overwrite the existing attrs. """ - return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.ModuleAttrs(attrs, allow_overwrite) # type: ignore[attr-defined] # pylint: disable=no-member + + +def current_ir_module() -> IRModuleFrame: + """Get the current ir_module frame. + Returns + ------- + frame: IRModuleFrame + The current frame. + """ + return _ffi_api.CurrentIRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_get_attrs() -> Dict[str, tvm_Object]: + """Get the attrs of the ir_module frame. + Returns + ------- + attrs: Dict[str, Object] + The module attrs. + """ + return _ffi_api.ModuleGetAttrs() # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_get_attr(attr_key: str) -> Optional[tvm_Object]: + """Get the specified attr of the ir_module frame. + Parameters + ---------- + attr_key: str + The key of the attr to be retrieved. + Returns + ------- + attr: Optional[Object] + The specified module attr or None if not found. + """ + return _ffi_api.ModuleGetAttr(attr_key) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_set_attr( + attr_key: str, attr_value: Optional[tvm_Object], allow_overwrite: bool = False +) -> None: + """Set the specified attr of the ir_module frame. + Parameters + ---------- + attr_key: str + The key of the attr to be set. + attr_value: Optional[Object] + The value of the attr to be set. + allow_overwrite: bool + Whether allow overwrite the existing attr. + """ + return _ffi_api.ModuleSetAttr(attr_key, attr_value, allow_overwrite) # type: ignore[attr-defined] # pylint: disable=no-member def module_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None: diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py new file mode 100644 index 000000000000..8c2467fad330 --- /dev/null +++ b/python/tvm/script/ir_builder/tir/external_kernel.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""External kernel integration fro TIR""" +import json +import logging +import tempfile +from typing import Any, Dict, List, Tuple, Union + +from tvm import __version__ as tvm_version +from tvm import tir +from tvm.runtime import Module, load_module + + +class BaseKernel: + """Base class for external kernels.""" + + def compile_to_device_module( + self, launch_args, *args, **kwargs + ) -> Tuple[str, Module, List[Any]]: + """Compile the kernel to a device module.""" + raise NotImplementedError() + + def _format_tvm_module_metadata(self, kernel_name, arg_types, launch_param_tags): + """Format the TVM module metadata.""" + tvm_metadata = """{{ + "tvm_version": "{version}", + "func_info": {{ + "{kernel_name}": {{ + "name": "", + "arg_types": {arg_types}, + "launch_param_tags": {launch_param_tags} + }} + }} + }}""".format_map( + { + "version": tvm_version, + "kernel_name": kernel_name, + "arg_types": json.dumps(arg_types), + "launch_param_tags": json.dumps(launch_param_tags), + } + ) + return tvm_metadata + + def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_name): + """ + Create a CUDA module from PTX and metadata. + + Parameters + ---------- + ptx : str + The PTX code of the kernel. + + kernel_arg_types : List[str] + The types of the kernel arguments. + + launch_param_tags : List[str] + The tags of the launch parameters. + + kernel_name : str + The name of the kernel. + + Returns + ------- + kernel_module : Module + The CUDA module. + """ + tvm_metadata = self._format_tvm_module_metadata( + kernel_name, kernel_arg_types, launch_param_tags + ) + with tempfile.TemporaryDirectory() as temp_dir: + ptx_path = f"{temp_dir}/{kernel_name}.ptx" + with open(ptx_path, "w") as f: + f.write(ptx) + with open(f"{temp_dir}/{kernel_name}.tvm_meta.json", "w") as f: + f.write(tvm_metadata) + kernel_module = load_module(ptx_path) + return kernel_module + + +def call_kernel( + kernel, + launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]], + *args: List[Any], + **kwargs: Dict[str, Any], +): + """ + Call an external kernel. + + Parameters + ---------- + kernel : Any + The external kernel to call. + + launch_args : List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]] + The launch arguments. A list of integers for grid size, block size, and shared memory size. + The actual requirements depend on the kernel. + + args : List[tir.PrimExpr] + The arguments to pass to the kernel. + + kwargs : Dict[str, Any] + Additional keyword arguments to pass to the kernel or compilation. + """ + from ..ir import module_get_attr, module_set_attr # pylint: disable=import-outside-toplevel + from .ir import call_packed # pylint: disable=import-outside-toplevel + + kernel_type = f"{type(kernel).__module__}.{type(kernel).__qualname__}" + if kernel_type == "triton.runtime.jit.JITFunction": + from .triton import TritonKernel # pylint: disable=import-outside-toplevel + + kernel = TritonKernel(kernel) + else: + raise ValueError("Unsupported kernel type {}".format(kernel_type)) + + kernel_name, kernel_module, runtime_args = kernel.compile_to_device_module( + launch_args, *args, **kwargs + ) + + # Attach the kernel module to the current IRModule + external_mods: List[Module] = module_get_attr("external_mods") or [] + kernel_exists = any([mod.implements_function(kernel_name) for mod in external_mods]) + if kernel_exists: + logging.debug("Kernel %s already exists in the IRModule", kernel_name) + else: + external_mods.append(kernel_module) + module_set_attr("external_mods", external_mods, True) + return call_packed(kernel_name, *runtime_args) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index bdbd6e2cdac0..f7face272de5 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -83,6 +83,7 @@ from tvm.tir.generic import cast from . import _ffi_api, frame +from .external_kernel import call_kernel # pylint: enable=unused-import @@ -1943,7 +1944,6 @@ def wrapped(*args, **kwargs): tvm_call_packed_lowered = call_packed_lowered tvm_call_cpacked_lowered = call_cpacked_lowered - # pylint: enable=invalid-name @@ -2255,4 +2255,5 @@ def wrapped(*args, **kwargs): "Range", "vscale", "get_active_lane_mask", + "call_kernel", ] diff --git a/python/tvm/script/ir_builder/tir/triton.py b/python/tvm/script/ir_builder/tir/triton.py new file mode 100644 index 000000000000..2d37d93a6dd8 --- /dev/null +++ b/python/tvm/script/ir_builder/tir/triton.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Triton kernel integration with TIR""" + +from typing import Tuple, List, Union, Any, Dict + +import triton +from triton.runtime.jit import type_canonicalisation_dict +from tvm import tir +from tvm.topi.utils import get_const_int +from tvm.runtime import Module +from .external_kernel import BaseKernel + + +class TritonKernel(BaseKernel): + """A kernel from Triton JIT function. + + This class bridges the Triton kernel with TVM runtime. The compilation includes the following + steps: + - Deduce the kernel signature and generate the Triton kernel + - Embed the compiled kernel into the current IRModule as an external module + - Generate a call to the Triton kernel following its calling convention via call_packed. + """ + + def __init__(self, func): + self.func = func + + def compile_to_device_module( + self, + launch_args: List[Union[int, tir.PrimExpr]], + *args: List[Any], + **kwargs: Dict[str, Any], + ) -> Tuple[str, Module, List[Any]]: + """Compile the kernel to a device module. + + Parameters + ---------- + launch_args : List[int] + The grid size of the kernel. A list of one to three expressions, representing the number + of + "blockIdx.x", "blockIdx.y", and "blockIdx.z" respectively. + + args : List[Any] + Arguments to the kernel function. + + kwargs : Dict[str, Any] + Additional options for the kernel compilation. + """ + triton_kernel, kernel_args = self._generate_triton_kernel(self.func, *args, **kwargs) + kernel_metadata = triton_kernel.metadata + ptx = triton_kernel.asm["ptx"] + assert kernel_metadata.num_ctas == 1, "Cluster is not supported" + num_warps = kernel_metadata.num_warps + grid = launch_args + launch_param_tags = ["threadIdx.x"] + ["blockIdx.x", "blockIdx.y", "blockIdx.z"][ + : len(grid) + ] + launch_args = [num_warps * 32] + list(grid) + kernel_arg_types = [arg.dtype for arg in kernel_args] + if triton_kernel.metadata.shared > 0: + # Add shared memory size to the launch arguments + launch_param_tags.append("tir.use_dyn_shared_memory") + launch_args.append(triton_kernel.metadata.shared) + + kernel_module = self._create_cuda_module( + ptx, kernel_arg_types, launch_param_tags, triton_kernel.name + ) + + return triton_kernel.name, kernel_module, kernel_args + launch_args + + def _generate_triton_kernel( + self, func, *args, **kwargs + ) -> Tuple["triton.compiler.CompiledKernel", List[tir.PrimExpr]]: + """Deduce the kernel signature and generate the Triton kernel""" + + kernel_params = func.params + assert len(kernel_params) == len( + args + ), f"Number of arguments does not match, expected {len(kernel_params)}, got {len(args)}" + + signature = {} + constants = {} + kernel_args = [] # Arguments to invoke the kernel + for i, arg in enumerate(args): + if kernel_params[i].is_constexpr: + constants[kernel_params[i].name] = get_const_int(arg) + continue + if arg.dtype == "handle": + assert isinstance(arg, tir.Var) + elem_type = arg.type_annotation.element_type.dtype + pointer_type = "*" + type_canonicalisation_dict[elem_type] + signature[kernel_params[i].name] = pointer_type + else: + signature[kernel_params[i].name] = type_canonicalisation_dict[arg.dtype] + kernel_args.append(arg) + + # TODO: Support default argument in the kernel + # TODO: Add specialization for aligned buffer pointers + source = triton.compiler.ASTSource(fn=func, constants=constants, signature=signature) + compiled = triton.compiler.compile(source, options=kwargs) + return compiled, kernel_args diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 2f2785ca4440..0fb4b256351b 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -88,17 +88,43 @@ void DefFunction(const String& func_name, const BaseFunc& func) { gv->checked_type_ = func->checked_type_; } -void ModuleAttrs(Map attrs) { +void ModuleAttrs(Map attrs, bool allow_overwrite) { if (IRBuilder::IsInScope()) { // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); - if (!frame->attrs.empty()) { + if (!allow_overwrite && !frame->attrs.empty()) { LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; } frame->attrs = attrs; } } +Optional ModuleGetAttr(const String& key) { + if (IRBuilder::IsInScope()) { + IRModuleFrame frame = FindModuleFrame(); + if (frame->attrs.find(key) != frame->attrs.end()) { + return frame->attrs[key]; + } + } + return NullOpt; +} + +void ModuleSetAttr(const String& key, const Optional& value, bool allow_override) { + if (IRBuilder::IsInScope()) { + IRModuleFrame frame = FindModuleFrame(); + if (!allow_override && frame->attrs.find(key) != frame->attrs.end() && value.defined()) { + LOG(FATAL) << "ValueError: Duplicate module attr " << key; + } + if (value.defined()) { + frame->attrs.Set(key, value.value()); + } else { + frame->attrs.erase(key); + } + } else { + LOG(FATAL) << "ValueError: Currently in in the scope of a module."; + } +} + void ModuleGlobalInfos(Map> global_infos) { if (IRBuilder::IsInScope()) { IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos"); @@ -143,6 +169,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGetAttr").set_body_typed(ModuleGetAttr); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleSetAttr").set_body_typed(ModuleSetAttr); TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); TVM_REGISTER_GLOBAL("script.ir_builder.ir.LookupVDevice").set_body_typed(LookupVDevice); diff --git a/tests/python/contrib/test_tir_triton_integration.py b/tests/python/contrib/test_tir_triton_integration.py new file mode 100644 index 000000000000..522351f3dc55 --- /dev/null +++ b/tests/python/contrib/test_tir_triton_integration.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import sys + +import tvm +from tvm.script import tir as T +from tvm.script import relax as R +from tvm.script import ir as I +from tvm import relax +from tvm.relax.frontend import nn +import tvm.testing +import pytest + +try: + import triton + import triton.language as tl +except ImportError: + pytestmark = pytest.skip("Triton is not available", allow_module_level=True) + + +@tvm.testing.requires_cuda +def test_tir_triton_integration(): + @triton.jit + def add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + ): + """Triton vector add kernel from its tutorial.""" + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + @I.ir_module + class Module: + @T.prim_func + def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None: + T.func_attr({"global_symbol": "add"}) + m = T.int64() + x = T.match_buffer(x_handle, (m,), "float32") + y = T.match_buffer(y_handle, (m,), "float32") + output = T.match_buffer(output_handle, (m,), "float32") + with T.block("root"): + T.reads(x[0:m], y[0:m]) + T.writes(output[0:m]) + BLOCK_SIZE = T.meta_var(64) + T.call_kernel( + add_kernel, + (T.ceildiv(m, BLOCK_SIZE),), + x.data, + y.data, + output.data, + m, + BLOCK_SIZE, + ) + + @R.function + def main(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")): + m = T.int64() + with R.dataflow(): + output = R.call_tir(Module.add, [x, y], relax.TensorStructInfo((m,), "float32")) + R.output(output) + return output + + @I.ir_module + class Parsed: + @T.prim_func + def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle): + m = T.int64() + x = T.match_buffer(x_handle, (m,)) + y = T.match_buffer(y_handle, (m,)) + output = T.match_buffer(output_handle, (m,)) + with T.block("root"): + T.reads(x[0:m], y[0:m]) + T.writes(output[0:m]) + T.call_packed( + "add_kernel", + x.data, + y.data, + output.data, + m, + 128, + (m + T.int64(64) - T.int64(1)) // T.int64(64), + ) + + tvm.ir.assert_structural_equal(Module["add"], Parsed["add"]) + assert len(Module.get_attr("external_mods")) == 1 + + device = tvm.cuda(0) + x_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device) + y_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device) + output_np = x_nd.numpy() + y_nd.numpy() + + with tvm.target.Target("cuda"): + lib = relax.build(Module) + output_nd = tvm.runtime.relax_vm.VirtualMachine(lib, device)["main"](x_nd, y_nd) + tvm.testing.assert_allclose(output_nd.numpy(), output_np, rtol=1e-5) From 30fb16a5e1d564ffa8533cf154c0ba2ea06dfd43 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Mon, 23 Sep 2024 06:34:46 -0700 Subject: [PATCH 166/202] [TVMjs] Modify web package description (#17405) --- web/package-lock.json | 12 ++++++------ web/package.json | 12 +++++++++++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/web/package-lock.json b/web/package-lock.json index 75efcbcc7b70..561ba770913f 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "tvmjs", - "version": "0.17.0-dev0", + "version": "0.18.0-dev0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.17.0-dev0", + "version": "0.18.0-dev0", "license": "Apache-2.0", "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", @@ -14,7 +14,7 @@ "@types/node": "^20.4.5", "@typescript-eslint/eslint-plugin": "^5.59.6", "@typescript-eslint/parser": "^5.59.6", - "@webgpu/types": "^0.1.40", + "@webgpu/types": "^0.1.42", "eslint": "^8.41.0", "jest": "^26.0.1", "rollup": "^2.56.2", @@ -1766,9 +1766,9 @@ } }, "node_modules/@webgpu/types": { - "version": "0.1.40", - "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.40.tgz", - "integrity": "sha512-/BBkHLS6/eQjyWhY2H7Dx5DHcVrS2ICj9owvSRdgtQT6KcafLZA86tPze0xAOsd4FbsYKCUBUQyNi87q7gV7kw==", + "version": "0.1.46", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.46.tgz", + "integrity": "sha512-2iogO6Zh0pTbKLGZuuGWEmJpF/fTABGs7G9wXxpn7s24XSJchSUIiMqIJHURi5zsMZRRTuXrV/3GLOkmOFjq5w==", "dev": true }, "node_modules/abab": { diff --git a/web/package.json b/web/package.json index 710185c5bcbc..a4e5d7ac086d 100644 --- a/web/package.json +++ b/web/package.json @@ -1,11 +1,21 @@ { "name": "tvmjs", - "displayName": "TVM Wasm JS runtime", + "description": "TVM WASM/WebGPU runtime for JS/TS", "license": "Apache-2.0", + "homepage": "https://github.com/apache/tvm/tree/main/web", "version": "0.18.0-dev0", "files": [ "lib" ], + "repository": { + "type": "git", + "url": "git+https://github.com/apache/tvm/tree/main/web" + }, + "keywords": [ + "llm", + "large language model", + "machine learning" + ], "main": "lib/index.js", "types": "lib/index.d.ts", "scripts": { From dfd9bd581d2d866d552c8e099568c6127aa3f971 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 24 Sep 2024 08:33:19 +0800 Subject: [PATCH 167/202] [Doc] Update Architecture Overview (#17402) * [Doc] Update Architecture Overview Update and reorganize architecture documentation This commit updates the architecture documentation by removing outdated files and reorganizing the content. It also updates related sections in the deep dive and developer tutorial. * lint * lint --- docs/arch/benchmark.rst | 137 ---- docs/arch/convert_layout.rst | 269 ------ docs/arch/frontend/tensorflow.rst | 254 ------ docs/arch/hybrid_script.rst | 100 --- docs/arch/index.rst | 218 ++--- docs/arch/inferbound.rst | 763 ------------------ docs/arch/microtvm_design.rst | 357 -------- docs/arch/microtvm_project_api.rst | 150 ---- docs/arch/model_library_format.rst | 171 ---- docs/arch/relay_intro.rst | 206 ----- docs/arch/relay_op_strategy.rst | 282 ------- docs/arch/virtual_machine.rst | 410 ---------- docs/deep_dive/relax/index.rst | 2 +- docs/deep_dive/tensor_ir/index.rst | 2 +- docs/dev/tutorial/codebase_walkthrough.rst | 2 +- docs/index.rst | 2 +- docs/reference/langref/relay_expr.rst | 4 +- docs/topic/microtvm/index.rst | 7 - .../tune_network_arm.py | 1 - .../tune_network_cuda.py | 1 - .../tune_network_mali.py | 1 - .../tune_network_x86.py | 1 - .../how_to/work_with_microtvm/micro_tvmc.sh | 2 +- 23 files changed, 81 insertions(+), 3261 deletions(-) delete mode 100644 docs/arch/benchmark.rst delete mode 100644 docs/arch/convert_layout.rst delete mode 100644 docs/arch/frontend/tensorflow.rst delete mode 100644 docs/arch/hybrid_script.rst delete mode 100644 docs/arch/inferbound.rst delete mode 100644 docs/arch/microtvm_design.rst delete mode 100644 docs/arch/microtvm_project_api.rst delete mode 100644 docs/arch/model_library_format.rst delete mode 100644 docs/arch/relay_intro.rst delete mode 100644 docs/arch/relay_op_strategy.rst delete mode 100644 docs/arch/virtual_machine.rst diff --git a/docs/arch/benchmark.rst b/docs/arch/benchmark.rst deleted file mode 100644 index 8217a4feb7df..000000000000 --- a/docs/arch/benchmark.rst +++ /dev/null @@ -1,137 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -******************************** -Benchmark Performance Log Format -******************************** -This page details schema v0.1 for a unified benchmark log format. This schema will allow easier cross-references with other frameworks/runs, experiment reproduction, data for nightly perf regression, and the separation of logging/visualization efforts. - -Log Format Overview -~~~~~~~~~~~~~~~~~~~ - -For simplicity, we suggest prioritizing the fields `workload`, `engine`, `hardware` `runtime_ms_mean`, and `runtime_ms_std`. For finer-grained logging, one may additionally propagate the `*_config` fields. - -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| header | examples | category | notes/justification | -+=======================+==============================================================================================================================================================================+==============+==============================================================================+ -| workload | resnet-18 | workload | name of workload | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| engine | "tvm" / "onnxruntime" | compiler | | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| hardware | "gcp-c2-standard-16" | hardware | descriptor of target hardware environment | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| runtime_ms_mean | 12.452 | statistics | | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| runtime_ms_std | 5.3 | statistics | | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| timestamp | 1572282699.6 | metadata | indicates when this record is logged | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| schema\_version | "0.1" | metadata | ensure reproducibility as we iterate on this schema | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| metadata | { "docker\_tag":"gcr.io/.../0a680", ... } | metadata | ``docker_tag`` is optional | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| workload\_args | {“input\_name”: "Input3", “input\_shape”: [list\_of\_shape], “data\_layout”: NHCW} | workload | | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| workload\_metadata | {"class": "vision","doc\_url": "``https://github.com/.../README.md``", "opset": 7,"type": "body\_analysis","url": "``https://onnxzoo...ferplus.tar.gz``", "md5": "07fc7..."} | workload | source of workload | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| engine\_version | "1.0.5" | compiler | use semvar format | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| engine\_config | {“llvm”: “llvm-8”, “nvcc”: 10.1, "accelerator": "MLAS", "relay_opt_level": 3, "tvm_target":"llvm -mcpu=cascadelake"} | compiler | fields are optionally specified | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| compilation\_config | {"opt_level": 3, "layer_schedules":[]/ } | compiler | fields are optionally specified | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| software\_config | {"os": "ubuntu:18.04","pip": { "docker": "4.1.0", "gitpython": "3.0.4", "numpy": "1.17.4", "onnx": "1.6.0"}, “cudnn”: “cudnn-8”, "cuda_driver”: “480.10.1”} | backend | env dependency list | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| runtime\_config | {"num_cpu_threads": 3} | backend | info on non-hardware, non-software metadata | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| hardware\_config | {"cpu_count": 16, "cloud_machine_type":"c2-standard-16", "memory_GB":64} | hardware | json descriptor of target hardware environment | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| execution\_config | {“number”: 1, “repeat”: 10, “min\_repeat\_ms”, 0} | statistics | workload execution parameters | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| metrics | {“accuracy”: 48.5,“compilation_ms_mean”: 12} | statistics | other metrics | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ -| runtime_raw | [{"runtime_ms": 12, ...}, {"runtime_ms":13,...},...] | statistics | optional raw metrics array | -+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+------------------------------------------------------------------------------+ - - - -Storage format -~~~~~~~~~~~~~~ -Currently we're prototyping benchmark data as JSON objects for extensibility and convenience, especially in early versions of the schema. However, as we scale up benchmark aggregation and stabilize parameters, we anticipate switching to a columnar format, such as Arrow or Parquet. - -Here is sample data encoded as JSON: - -:: - - { - "workload":"arcface_resnet100", - "engine":"tvm", - "hardware":"gcp-c2-standard-16", - "runtime_ms_mean":109.43004820081924, - "runtime_ms_std":0.09078385126800587, - "timestamp":"20191123003411", - "schema_version":"0.1", - "metadata":{ - "docker_tag":"tlcpack/ci-gpu:v0.53" - }, - "workload_args":{ - "input_shape_dict":{ - "data":[ - 1, - 3, - 112, - 112 - ] - }, - "input_type_dict":{ - "data":"float32" - }, - "input_value_dict":{} - }, - "workload_metadata":{ - "class":"vision", - "doc_url":"https://github.com/onnx/models/blob/main/vision/body_analysis/arcface/README.md", - "md5":"66074b860f905295aab5a842be57f37d", - "opset":8, - "type":"body_analysis", - "url":"https://s3.amazonaws.com/onnx-model-zoo/arcface/resnet100/resnet100.tar.gz" - }, - "engine_version":"1.0.0", - "engine_config":{}, - "compilation_config":{ - "relay_opt_level": 3 - }, - "software_config":{ - "os":"ubuntu:18.04", - "pip":{ - "docker":"4.1.0", - "gitpython":"3.0.4", - "numpy":"1.17.4", - "onnx":"1.6.0" - } - }, - "runtime_config":{}, - "hardware_config":{ - "cloud_machine_type":"c2-standard-16", - "cloud_provider":"GCP", - "cpu_count":16, - "cpu_platform":"Intel Cascade Lake", - "memory_GB":64 - }, - "execution_config":{}, - "metrics":{} - } diff --git a/docs/arch/convert_layout.rst b/docs/arch/convert_layout.rst deleted file mode 100644 index 51917fce44df..000000000000 --- a/docs/arch/convert_layout.rst +++ /dev/null @@ -1,269 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at -.. http://www.apache.org/licenses/LICENSE-2.0 -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -=================== -Convert Layout Pass -=================== -**Author**: `Animesh Jain `_ - -************* -1. Background -************* - -Data layout format describes how the data is laid out in the memory. For example, Tensorflow framework default data layout for convolution operator is NHWC, i.e, the data is 4-dimensions and is laid out in row-major format with N being the first dimension and C being the last dimension. Data layout has a major role in model performance, significantly affecting spatial and temporal locality. For example, Intel x86 backend in TVM prefers layout as NCHWc where the C dimension is tiled in 2 dimensions to exploit data locality efficiently. Similarly, CUDA backend prefers the data layout to be in NCHW format. - -Essentially, TVM has to deal with data layouts throughout the compiler toolchain - Framework parsers, Relay layout transformations, and TOPI schedules. As we move towards third-party codegen integration, which might have their own data layout restrictions, handling layouts at all levels in TVM toolchain is going to become even more challenging. Therefore, we developed a new Relay pass - **ConvertLayout** -- to reduce some of the complications that arise due to layout handling. - -If you directly want to understand the usage of ConvertLayout Pass, directly jump to Section 4 - Usage. - -************************** -2. Motivation and Overview -************************** - -Let's look at a simple scenario to understand the complications that arise due to different layouts - Suppose we want to compile a Tensorflow NHWC graph for an ARM edge device. But, suppose we currently support only NCHW schedules in TOPI for ARM. So, there is a mismatch between framework layout and TOPI-supported layout. One way to deal with this mismatch is to insert layout transforms before each and after convolution, such that resulting convolution has NCHW input data layout and can use TOPI schedules. However, this can lead to performance degradation because of the presence of too many layout transforms. - -We encountered similar problems in other use cases as well - -- No way to run TFLite graphs on Nvidia GPUs. TOPI has NCHW-only schedules for GPUs. -- Ever-complicating logic in AlterOpLayout for convolution to support different pairs of layout transformations. -- Sub-optimal performance for TF graphs due to extra layout transforms. -- Complication in third-party codegen integrations like TensorRT that prefers data layout to be in one format. - -To solve these problems, we introduced *ConvertLayout* pass that sets up the infrastructure to change the data layout of the whole graph with minimal number of data layout transforms. In ideal cases, we will have only 2 layout transforms for data, one at the start and one at the end. An example to show the transformation is below - - -.. code-block:: python - - # Original graph - 2 convolutions in NHWC format. - fn (%x: Tensor[(1, 56, 56, 64), float32], %weight1: Tensor[(3, 3, 64, 32), float32], %weight2: Tensor[(3, 3, 32, 32), float32]) { - %0 = nn.conv2d(%x, %weight1, padding=[1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO"); - %1 = nn.relu(%0); - %2 = nn.conv2d(%1, %weight2, padding=[1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO"); - nn.relu(%2) - } - - # After ConvertLayout - For data, there is a transform at the start and at the end. - # For weights, there are transforms to adapt to NCHW layout. These will be removed by FoldConstant pass. - fn (%x: Tensor[(1, 56, 56, 64), float32], %weight1: Tensor[(3, 3, 64, 32), float32], %weight2: Tensor[(3, 3, 32, 32), float32]) { - %0 = layout_transform(%x, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 64, 56, 56), float32] */; - %1 = layout_transform(%weight1, src_layout="HWIO", dst_layout="OIHW") /* ty=Tensor[(32, 64, 3, 3), float32] */; - %2 = nn.conv2d(%0, %1, padding=[1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 56, 56), float32] */; - %3 = nn.relu(%2) /* ty=Tensor[(1, 32, 56, 56), float32] */; - %4 = layout_transform(%weight2, src_layout="HWIO", dst_layout="OIHW") /* ty=Tensor[(32, 32, 3, 3), float32] */; - %5 = nn.conv2d(%3, %4, padding=[1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 56, 56), float32] */; - %6 = nn.relu(%5) /* ty=Tensor[(1, 32, 56, 56), float32] */; - layout_transform(%6, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 56, 56, 32), float32] */ - } - - -********* -3. Design -********* - -Before delving into ConvertLayout pass, let's categorize the operators into 3 categories based on their sensitivity to data layouts. This categorization will be useful later to understand Convertlayout pass details. - -- **Layout agnostic** - Relu, pow etc. These operators are not affected, neither functionality nor performance, by data layouts. -- **Lightly-layout sensitive** - pad, concatenate, reduce ops like sum etc. These operators have some attributes that are functionally affected if we do a layout transformation before them. However, performance-wise, the difference is not significant. For these operators, it is beneficial to just adapt to the previous operator output data layout. -- **Heavily-layout sensitive** - Convolution, conv2d_transpose etc. These operators are heavily affected, both functionally and performance-wise, by data layouts. They also have data layout as the op attribute. Typically, it is beneficial to modify the input data layouts for these operators (if its not a performant data layout), while the rest of *layout agnostic* and *lightly-layout sensitive* operators adapt to the layout governed by the output of these *heavliy-layout sensitive* operators. - - -Let us now look at two relevant Relay operator properties. Each relay operator has properties, like InferType, that can be defined by a TVM developer. Typically, a Relay pass traverses the graph operator-by-operator and reads these operator properties. For example, InferType pass looks at the InferType property of on operator, determines its output shape and type, and then passes it to the next operator InferType property. Similarly, in our context, we have 2 such properties - *FTVMConvertLayout* and *FInferCorrectLayout*. ConvertLayout pass traverses the graph and looks at these 2 properties along with an automatic layout transform insertion module to handle data layouts. So, the whole process can be broken down into 3 steps: - -- Run FTVMConvertLayout property - This allows the developers to transform the original Relay expr into a new Relay expr with new layouts, allowing user-defined layout alteration. There is a python callback for developer's ease. This is used only for heavily-layout sensitive operators. -- Run FTVMInferCorretLayout property - We can view this as layout inference. It looks at the original input layout and the new input layouts, which are either coming from previous operator or from the FTVMConvertLayout modified expr (if it was used). This can be used by lightly-layout sensitive operators to adapt its attributes to new data layouts. Layout inference happens for each operator. -- Automatic insertion of layout transforms - The previous step - layout inference - sets the new layout for the input exprs. If these layouts are different from the original layouts, then this component automatically inserts a layout transform. Therefore, a developer does not need to do anything for this component. - -These steps happen for each operator in sequence, where ConvertLayout pass keeps on passing the new layouts to the next operator properties, finally resulting in modifying the whole graph operator-by-operator. Now, let's look at a couple of examples of how to define the two properties. - -**FTVMConvertLayout - Python callback for layout alteration** - This is used for *heavily-layout sensitive* operators. For example, one can return a new convolution operator with new data and kernel layout. The other 2 components will infer layout and insert layout transforms if needed. One example for convolution operator is as follows where we are converting to NCHW layout. - -.. code-block:: python - - @reg.register_convert_op_layout("nn.conv2d") - def convert_conv2d(attrs, inputs, tinfos, desired_layouts): - """Convert Layout pass registration for conv2d op. - - Parameters - ---------- - attrs : tvm.attrs.Attrs - Attributes of current convolution - inputs : list of tvm.relay.Expr - The args of the Relay expr to be legalized - tinfos : list of types - List of input and output types - desired_layouts : list of layout strings - List of layouts defining our desired - layout for the data and kernel inputs respectively. - - Returns - ------- - result : tvm.relay.Expr - The transformed expr - """ - - from tvm import relay - data, weight = inputs - new_attrs = dict(attrs) - - # We expect 2 desired layouts to be specified, one for the data and one for the kernel. - assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" - - # Use the first entry in desired layouts which specifies the data layout. - # The expected ordering of layouts for this operator is defined by this function. - desired_data_layout, desired_kernel_layout = map(str, desired_layouts) - - assert desired_data_layout != "default", "Data layout cannot be default" - - new_attrs['data_layout'] = desired_data_layout - - if desired_data_layout == 'NCHW': - if desired_kernel_layout != 'default': - new_attrs['kernel_layout'] = desired_kernel_layout - else: - new_attrs['kernel_layout'] = 'OIHW' - # Actual insertion of layout transforms is taken care internally - # by ConvertLayout pass. - return relay.nn.conv2d(data, weight, **new_attrs) - - raise ValueError('Layout %s is not yet supported' % desired_data_layout) - - -**FInferCorrectLayout - Layout inference** - Currently, this attribute is exposed only in C++. This function takes original input layouts and the new input layouts (passed from the previous operator or from the python callback for layout alteration), and infers the final data layouts. Layout inference is called for each operator. The usage might vary for different operator categories. For layout agnostic operators, we just want to return the new data layouts in this function. For lightly-layout and heavily-layout sensitive operators, we can change the operator attributes (like axis for concatenate, pad_width for pad) so that we can adapt to the new data layout, preventing insertion of layout transforms. Let's look at a couple of examples to understand this better. - -First example is for layout agnostic operators. These operators do not have any operator attributes that are affected by data layouts, so we just adapt to new layouts. - -.. code-block:: c++ - - // For operator set its attributes like following - // .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); - - // Take arbitrary input layouts and copy to outputs. - inline Array> ElemwiseArbitraryLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array> &old_in_shapes) { - Layout ret; - - if (new_in_layouts.defined()) { - ICHECK_GE(new_in_layouts.size(), 1); - ret = new_in_layouts[0]; - } else { - for (size_t i = 0; i < old_in_layouts.size(); ++i) { - if (old_in_layouts[i].defined()) { - ret = old_in_layouts[i]; - break; - } - } - } - - return Array>{Array(old_in_layouts.size(), ret), {ret}}; - } - - -Second example is for a lightly-layout sensitive operator - batch normalization. BatchNorm has an axis operator that has to change when we go from NHWC to NCHW data layout. (Similar handling also needs to be for heavily-layout sensitive operators) - - -.. code-block:: c++ - - Array> BatchNormInferCorrectLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array>& old_in_shapes) { - BatchNormAttrs* param = const_cast(attrs.as()); - - size_t axis = - param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast(param->axis); - - Layout ret = Layout::Undef(); - - // For example, consider old_layout = NHWC, and new_layout = NCHW, and param->axis = 3 - - if (new_in_layouts.defined() && old_in_layouts.defined()) { - // Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout. - - // Following line gives bn_dim = C as old_layout = NHWC, axis = 3 - const auto& bn_dim = old_in_layouts[0][axis]; - - // The new_index is 1 because new_layout = NCHW and bn_dim is C - auto new_index = new_in_layouts[0].IndexOf(bn_dim); - - // We modify the layout-dependent attribute here - axis to 1. - param->axis = new_index; - - // Finally, we adapt to the new layout. - ret = new_in_layouts[0]; - - } else if (old_in_layouts.defined()) { - ret = old_in_layouts[0]; - } - - // In case both new and old layouts are undefined, then there is no need of a change. - // ConvertLayout pass skips the automatic insertion of layout transforms in this case. - - // Following line is not important to tutorial. But, layout inference needs to define - // the layout for all input and output data layouts. For batch norm, the other inputs - // and outputs are vector having length of C dim in the input. So, we set the other - // layouts as C. BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs - // have "C" layout. - Layout c_layout = Layout("C"); - - return Array>{{ret, c_layout, c_layout, c_layout, c_layout}, - {ret, c_layout, c_layout}}; - } - - -******** -4. Usage -******** -.. _convert-layout-usage: - -ConvertLayout pass is extremely easy to use. The pass is not a part of default relay.build pipeline. The intended usage is to call it between the framework-to-relay parser and relay.build module call. - -In order to specify the layouts to convert to, we create a mapping of heavily-layout sensitive operators to a list of the desired layouts for that operator. The first example below specifies data layout, we allow the kernel layout to be automatically converted to one that is supported by TVM (for that particular data layout and operator). This is specified by the use of the "default" keyword. The second example shows how we could have also converted to a specific kernel layout of our choosing. It's worth noting that the following examples will convert to the same layouts i.e. `{'nn.conv2d': ['NCHW', 'default']} == {'nn.conv2d': ['NCHW', 'OIHW']}` - -.. code-block:: python - - # TFlite framework to Relay parser - Default layout is NHWC - mod, params = relay.frontend.from_tflite(tflite_model, - shape_dict=shape_dict, - dtype_dict=dtype_dict) - - # We assume our model's heavily-layout sensitive operators only consist of nn.conv2d - desired_layouts = {'nn.conv2d': ['NCHW', 'default']} - - # Convert the layout to NCHW - # RemoveUnunsedFunctions is used to clean up the graph. - seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(), - relay.transform.ConvertLayout(desired_layouts)]) - with tvm.transform.PassContext(opt_level=3): - mod = seq(mod) - - # Call relay compilation - with relay.build_config(opt_level=3): - graph, lib, params = relay.build(mod, target, params=params) - - -.. code-block:: python - - desired_layouts = {'nn.conv2d': ['NCHW', 'OIHW']} - pass = relay.transform.ConvertLayout(desired_layouts) - - -The ordering of the layouts is defined by the implementation of `register_convert_op_layout("OPNAME")`, you can refer to the docstring which should explicitly state the expected layout. In the examples above it's [data_layout, kernel_layout]. - -Current implementation has support for almost all the operators commonly used in image classification models. However, if one encounters too many data layout transforms in the graph, it is highly likely that there is an operator whose layouts need special handling as described in Section 3. Some pull requests that can help in such a situation are - -- Layout inference for `Batch Norm `_ - Batch normalization falls into the category of lightly-sensitive operator. The PR shows how to handle the layout inference for batch norm. -- Python Callback for `Convolution `_- For highly-sensitive operators, one might have to do python callback as well. The PR shows how to define a python callback function for Convolution operator. diff --git a/docs/arch/frontend/tensorflow.rst b/docs/arch/frontend/tensorflow.rst deleted file mode 100644 index dde7179d90db..000000000000 --- a/docs/arch/frontend/tensorflow.rst +++ /dev/null @@ -1,254 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -TensorFlow Frontend -=================== - -The TensorFlow frontend helps in importing TensorFlow models into TVM. - -Supported versions: - -- 1.12 and below - -Tested models: - -- Inception (V1/V2/V3/V4) -- Resnet (All) -- Mobilenet (V1/V2 All) -- Vgg (16/19) -- BERT (Base/3-layer) - -Preparing a Model for Inference -------------------------------- - -Remove Unneeded Nodes -~~~~~~~~~~~~~~~~~~~~~ - -The export process will remove many nodes that are not needed for inference, but unfortunately will leave some remaining. The nodes that should be manually removed are: - -- Dropout, including `Dropout`_ and `DropoutWrapper`_ -- `Assert`_ - -.. _Dropout: https://www.tensorflow.org/api_docs/python/tf/nn/dropout -.. _DropoutWrapper: https://www.tensorflow.org/versions/r1.12/api_docs/python/tf/nn/rnn_cell/DropoutWrapper?hl=hr -.. _Assert: https://www.tensorflow.org/api_docs/python/tf/debugging/Assert - -Convert None Dimensions to Constants -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TVM has minimal support for dynamic tensor shapes. Dimensions that are ``None`` should be replaced with constants. For example, a model may accept an input with shape ``(None,20)``. This should be converted to a shape like ``(1,20)``. The model should be modified accordingly to ensure that these shapes match throughout the graph. - -Export -~~~~~~ - -TensorFlow frontend expects a frozen protobuf (.pb) or saved model as input. It currently does not support checkpoint (.ckpt). The graphdef needed by the TensorFlow frontend can be extracted from the active session, or by using the `TFParser`_ helper class. - -.. _TFParser: https://github.com/apache/tvm/blob/main/python/tvm/relay/frontend/tensorflow_parser.py - -The model should be exported with a number of transformations to prepare the model for inference. It is also important to set ```add_shapes=True```, as this will embed the output shapes of each node into the graph. Here is one function to export a model as a protobuf given a session: - -.. code:: python - - import tensorflow as tf - from tensorflow.tools.graph_transforms import TransformGraph - - def export_pb(session): - with tf.gfile.GFile("myexportedmodel.pb", "wb") as f: - inputs = ["myinput1", "myinput2"] # replace with your input names - outputs = ["myoutput1"] # replace with your output names - graph_def = session.graph.as_graph_def(add_shapes=True) - graph_def = tf.graph.util.convert_variables_to_constants(session, graph_def, outputs) - graph_def = TransformGraph( - graph_def, - inputs, - outputs, - [ - "remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)", - "sort_by_execution_order", # sort by execution order after each transform to ensure correct node ordering - "remove_attribute(attribute_name=_XlaSeparateCompiledGradients)", - "remove_attribute(attribute_name=_XlaCompile)", - "remove_attribute(attribute_name=_XlaScope)", - "sort_by_execution_order", - "remove_device", - "sort_by_execution_order", - "fold_batch_norms", - "sort_by_execution_order", - "fold_old_batch_norms", - "sort_by_execution_order" - ] - ) - f.write(graph_def.SerializeToString()) - -Another method is to `export and freeze the graph `_. - -Import the Model ----------------- - -Explicit Shape: -~~~~~~~~~~~~~~~ - -To ensure shapes can be known throughout the entire graph, pass the ```shape``` argument to ```from_tensorflow```. This dictionary maps input names to input shapes. Please refer to these `test cases `_ for examples. - -Data Layout -~~~~~~~~~~~ - -Most TensorFlow models are released with NHWC layout. NCHW layout often provides better performance, especially on GPU. The TensorFlow frontend can automatically convert the model's data layout by passing the argument ```layout='NCHW'``` to ```from_tensorflow```. - -Best Practices --------------- - -- Use static tensor shapes instead of dynamic shapes (remove ```None``` dimensions). -- Use static RNN instead of dynamic RNN, as ```TensorArray``` isn't supported yet. - -Supported Ops -------------- - -- Abs -- Add -- AddN -- All -- Any -- ArgMax -- ArgMin -- AvgPool -- BatchMatMul -- BatchMatMulV2 -- BatchNormWithGlobalNormalization -- BatchToSpaceND -- BiasAdd -- BroadcastTo -- Cast -- Ceil -- CheckNumerics -- ClipByValue -- Concat -- ConcatV2 -- Conv2D -- Cos -- Tan -- CropAndResize -- DecodeJpeg -- DepthwiseConv2dNative -- DepthToSpace -- Dilation2D -- Equal -- Elu -- Enter -- Erf -- Exit -- Exp -- ExpandDims -- Fill -- Floor -- FloorDiv -- FloorMod -- FusedBatchNorm -- FusedBatchNormV2 -- Gather -- GatherNd -- GatherV2 -- Greater -- GreaterEqual -- Identity -- IsFinite -- IsInf -- IsNan -- LeakyRelu -- LeftShift -- Less -- LessEqual -- Log -- Log1p -- LoopCond -- LogicalAnd -- LogicalOr -- LogicalNot -- LogSoftmax -- LRN -- LSTMBlockCell -- MatMul -- Max -- MaxPool -- Maximum -- Mean -- Merge -- Min -- Minimum -- MirrorPad -- Mod -- Mul -- Neg -- NextIteration -- NotEqual -- OneHot -- Pack -- Pad -- PadV2 -- Pow -- Prod -- Range -- Rank -- RealDiv -- Relu -- Relu6 -- Reshape -- ResizeBilinear -- ResizeBicubic -- ResizeNearestNeighbor -- ReverseV2 -- RightShift -- Round -- Rsqrt -- Select -- Selu -- Shape -- Sigmoid -- Sign -- Sin -- Size -- Slice -- Softmax -- Softplus -- SpaceToBatchND -- SpaceToDepth, -- Split -- SplitV -- Sqrt -- Square -- SquareDifference -- Squeeze -- StridedSlice -- Sub -- Sum -- Switch -- Tanh -- TensorArrayV3 -- TensorArrayScatterV3 -- TensorArrayGatherV3 -- TensorArraySizeV3 -- TensorArrayWriteV3 -- TensorArrayReadV3 -- TensorArraySplitV3 -- TensorArrayConcatV3 -- Tile -- TopKV2 -- Transpose -- TruncateMod -- Unpack -- UnravelIndex -- Where -- ZerosLike diff --git a/docs/arch/hybrid_script.rst b/docs/arch/hybrid_script.rst deleted file mode 100644 index a4fce342f728..000000000000 --- a/docs/arch/hybrid_script.rst +++ /dev/null @@ -1,100 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -Hybrid Frontend Developer Guide -=============================== - -If you are a developer: - -1. who is trying writing some preliminary patterns that have not been supported by TVM yet, -maybe :ref:`hybrid-langref-label` is a better place for you. - -2. who wants to know the implementation details of this module, you are right here! - -Features --------- - -Software Emulation -~~~~~~~~~~~~~~~~~~ - -In software emulation, the most interesting thing is the decorator ``tvm.te.hybrid.script``. -This decorator helps 2 things: - -1. Importing runtime variables - -2. Overloading the function according to the arguments passed - -Correct me if I am wrong: I believe that how 1. is implemented is dangerous, but I have no -choice. What I did is to add those names into python dict ``func.__global__`` and after -the call to ``func`` is done, those names will be cleaned up. - -Overload is simple: the decorator checks the arguments' types and determines which function -should be actually called. - - -Backend Compilation -~~~~~~~~~~~~~~~~~~~ - -Compilation is a large module, you can see ``python/tvm/te/hybrid/`` for more -details. The first stage determines the usage, or more accurately the -declaration of each variable and the second stage does the actual IR -generation. - -Attributes -~~~~~~~~~~ - -So far, ONLY tensors' `shape` attribute is supported. You can see ``visit_Subscript`` -in ``python/tvm/te/hybrid/parser.py`` for more details. This is a hacky solution, I just -check the attributes when subscript. - -Loops -~~~~~ - -In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``. - - -.. note:: - - Unlike what that is in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b`` - is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. Thus, when lowering it - to HalideIR, we need to do ``start, extent = a, b - a`` - - -.. note:: - - In HalideIR those are enums, they are in passive form. - Here we use active form to annotate loops, because they are ready to run. - - -Variables -~~~~~~~~~ - -Because there is no variables in ``HalideIR``, all the mutable variables will be lowered to an array with size 1. -It takes the first store of a variable as its declaration. - -Math Intrinsics -~~~~~~~~~~~~~~~ -So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``, are supported. -Math intrinsics will be imported by the decorator. Most of the intrinsics are borrowed by library implementation -except ``popcount`` and ``sigmoid``. I implemented them manually. - - -Casting -~~~~~~~ - -You can cast values by using the keywords ``uint8``, ``uint16`` ``uint32``, ``uint64``, ``int8``, ``int16``, ``int32``, ``int64``, -``float16``, ``float32``, ``float64``. diff --git a/docs/arch/index.rst b/docs/arch/index.rst index 17884a774253..cf4829268ee2 100644 --- a/docs/arch/index.rst +++ b/docs/arch/index.rst @@ -18,46 +18,37 @@ Design and Architecture ======================= -This document is intended for developers who want to understand the -architecture of TVM and/or actively develop on the project. +This document is intended for developers who want to understand the architecture of Apache TVM and/or actively develop on the project. This page is organized as follows: -- The `Example Compilation Flow`_ gives an overview of the steps that TVM takes to turn a high level description of a model into a deployable module. +- The `Overall Flow`_ gives an overview of the steps that TVM takes to turn a high level description of a model into a deployable module. To get started, please read this section first. - -- The `Logical Architecture Components`_ section describes the logical components. - The sections after are specific guides focused on each logical component, organized - by the component's name. - -- The :ref:`Device/Target Interactions ` - page describes how TVM interacts with each supported physical device - and code-generation target. - -- Feel free to also check out the :ref:`dev-how-to` for useful development tips. +- Brief introduction to the key components of the TVM stack. Feel free to also check out the :ref:`TensorIR Deep Dive ` + and :ref:`Relax Deep Dive ` for more details about the two major components in the TVM stack. This guide provides a few complementary views of the architecture. First, we review a single end-to-end compilation flow and discuss the key data structures and the transformations. This runtime-based view focuses on the interactions of each components when running the compiler. Then we will review the logical modules of the codebase and their relationship. This part provides a static overarching view of the design. - -Example Compilation Flow ------------------------- +Overall Flow +------------ In this guide, we will study an example compilation flow in the compiler. The figure below shows the flow. At a high-level, it contains several steps: -- Import: The frontend component ingests a model into an IRModule, which contains a collection of functions that internally represent the model. -- Transformation: The compiler transforms an IRModule to another functionally equivalent or approximately +- **Model Creation**: Create the IRModule to be optimized and compiled, which contains a collection of functions that internally represent the model. + Users can manually construct IRModule via NNModule, TVMScript, or import a pre-trained model from from Relax frontend. +- **Transformation**: The compiler transforms an IRModule to another functionally equivalent or approximately equivalent(e.g. in the case of quantization) IRModule. Many of the transformations are target (backend) independent. We also allow target to affect the configuration of the transformation pipeline. -- Target Translation: The compiler translates(codegen) the IRModule to an executable format specified by the target. +- **Target Translation**: The compiler translates(codegen) the IRModule to an executable format specified by the target. The target translation result is encapsulated as a `runtime.Module` that can be exported, loaded, and executed on the target runtime environment. -- Runtime Execution: the user loads back a `runtime.Module` and runs the compiled functions in the supported runtime environment. +- **Runtime Execution**: the user loads back a `runtime.Module` and runs the compiled functions in the supported runtime environment. -.. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_dyn_workflow.svg +.. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg :align: center - :width: 85% + :width: 80% Key data structures @@ -70,13 +61,14 @@ components that either define a collection of key data structures or transformat **IRModule** is the primary data structure used across the entire stack. An IRModule (intermediate representation module) contains a collection of functions. Currently, we support two primary variants of functions. -- **relay::Function** is a high-level functional program representation. A relay.Function usually corresponds to an end-to-end model. - You can view a relay.Function as a computational graph with additional support for control-flow, recursion, and complex data structures. +- **relax::Function** is a high-level functional program representation. A relax.Function represents high-level graph structure, + usually corresponds to an end-to-end model or a sub-graph of the overall model. You can view a relax.Function as a computational + graph with additional support for control-flow, and complex data structures. - **tir::PrimFunc** is a low-level program representation that contains elements including loop-nest choices, multi-dimensional load/store, threading, and vector/tensor instructions. It is usually used to represent an operator program that executes a (possibly-fused) layer in a model. -During the compilation, a relay function may be lowered to multiple tir::PrimFunc functions and a top-level function that calls into -those tir::PrimFunc functions. +During the compilation and transformation, all relax operators are lowered to ``tir::PrimFunc`` or ``TVM PackedFunc``, which can be executed directly +on the target device, while the calls to relax operators are lowered to calls to low-level functions (e.g. ``R.call_tir`` or ``R.call_dps``). Transformations ~~~~~~~~~~~~~~~ @@ -86,44 +78,35 @@ Now that we have covered the key data structures, let us talk about the transfor - optimization: transform a program to an equivalent, possibly more optimized version. - lowering: transform a program to a lower-level representation that is closer to the target. -**relay/transform** contains a collection of passes that optimize the model. The optimizations include common program -optimizations such as constant folding and dead-code elimination, and tensor-computation specific passes such as layout -transformation and scaling factor folding. - -Near the end of the relay optimization pipeline, we will run a pass(FuseOps) to break the end-to-end function(e.g. MobileNet) -into sub-function(e.g. conv2d-relu) segments. We call these segments of functions. -This process helps us to divide the original problem into two sub-problems: - -- Compilation and optimization for each sub-function. -- Overall execution structure: we need to do a sequence of calls into the generated sub-functions to execute the whole model. - -We use the low-level tir phase to compile and optimize each sub-functions. For specific targets, we may also directly go to the target translation -phase and use external code generators. - -There are a few different ways(in relay/backend) to handle the calls into the overall execution problem. For simple models with known shapes and no control flow, we can lower to a graph executor that stores the execution structure in a graph. We also support a virtual machine backend for dynamic executions. Finally, we plan to support ahead of time compilation that compiles the high-level execution structure into the executable and generated primitive functions. All of these execution modes are encapsulated by a unified **runtime.Module** interface, which we will discuss in the latter part of the guide. +relax transformations +^^^^^^^^^^^^^^^^^^^^^ +relax transformations contain a collection of passes that apply to relax functions. The optimizations include common graph-level +optimizations such as constant folding and dead-code elimination for operators, and backend-specific optimizations such as library dispatch. -**tir/transform** contains transformation passes for TIR level functions. Many tir passes serve the purpose of lowering. For example, there are passes to flatten multi-dimensional access to one-dimensional pointer access, to expand the intrinsics into target-specific ones, and to decorate the function entry to meet the runtime calling convention. Of course, there are also optimizations passes, such as access index simplification and dead code elimination. +tir transformations +^^^^^^^^^^^^^^^^^^^ +tir transformations contain a collection of passes that apply to tir functions. There are two major types of transformations: -Many low-level optimizations can be handled in the target phase by the LLVM, CUDA C, and other target compilers. As a result, we leave low-level optimizations such as register allocation to the downstream compilers and only focus on optimizations that are not covered by them. +- **TensorIR schedule**: TensorIR schedules are designed to optimize the TensorIR functions for a specific target, with user-guided instructions and control how the target code is generated. + For CPU targets, TIR PrimFunc can generate valid code and execute on the target device without schedule but with very-low performance. However, for GPU targets, the schedule is essential + for generating valid code with thread bindings. For more details, please refer to the :ref:`TensorIR Transformation ` section. Additionally, we provides ``MetaSchedule`` to + automate the search of TensorIR schedule. +- **Lowering Passes**: These passes usually perform after the schedule is applied, transforming a TIR PrimFunc into another functionally equivalent PrimFunc, but closer to the + target-specific representation. For example, there are passes to flatten multi-dimensional access to one-dimensional pointer access, to expand the intrinsics into target-specific ones, + and to decorate the function entry to meet the runtime calling convention. -Search-space and Learning-based Transformations -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Many low-level optimizations can be handled in the target phase by the LLVM, CUDA C, and other target compilers. As a result, we leave low-level optimizations such as register allocation + to the downstream compilers and only focus on optimizations that are not covered by them. -The transformation passes we described so far are deterministic and rule-based. One design goal of the TVM stack is to support high-performance code optimizations for different hardware platforms. To do so, we will need to investigate as many optimization choices as possible, including but not limited to, multi-dimensional tensor access, loop tiling behavior, special accelerator memory hierarchy, and threading. +cross-level transformations +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Apache TVM brings a unity strategy to optimize the end-to-end models. As the IRModule includes both relax and tir functions, the cross-level transformations are designed to mutate +the IRModule by applying different transformations to these two types of functions. -It is hard to define a heuristic to make all of the choices. Instead, we will take a search and learning-based approach. -We first define a collection of actions we can take to transform a program. Example actions include loop transformations, inlining, -vectorization. We call these actions **scheduling primitives**. The collection of scheduling primitives defines a search space of possible -optimizations we can make to a program. The system then searches over different possible scheduling -sequence to pick the best scheduling combination. -The search procedure is usually guided by a machine learning algorithm. - -We can record the best schedule sequence for an (possibly-fused) operator once the search is completed. The compiler can then just lookup the best -schedule sequence and apply it to the program. Notably, this schedule application phase is **exactly like** the rule-based transformations, -enabling us to share the same interface convention with tradition passes. - -We use search based optimizations to handle the initial tir function generation problem. This part of the module is called AutoTVM(auto_scheduler). -We expect to expand the learning-based transformations to more areas as we continue to develop the TVM stack. +For example, ``relax.LegalizeOps`` pass mutates the IRModule by lowering relax operators, add corresponding TIR PrimFunc into the IRModule, and replace the relax operators +with calls to the lowered TIR PrimFunc. Another example is operator fusion pipeline in relax (including ``relax.FuseOps`` and ``relax.FuseTIR``), which fuse multiple consecutive tensor operations +into one. Different from the previous implementations, relax fusion pipeline analyzes the pattern of TIR functions and detects the best fusion rules automatically rather +than human-defined operator fusion patterns. Target Translation ~~~~~~~~~~~~~~~~~~ @@ -204,19 +187,6 @@ except that the data structure of interest changes from the numpy.ndarray to tvm - Manipulate the IR directly using TVM's python API. -Logical Architecture Components -------------------------------- - -.. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_static_overview.svg - :align: center - :width: 85% - - TVM Architecture Diagram - -The above figure shows the major logical components in the project. Please read the following sections -for information about the components and their relations. - - tvm/support ----------- The support module contains the most common utilities for the infrastructure, such as generic arena allocator, socket, and logging. @@ -243,22 +213,19 @@ These hardware-specific runtime modules define APIs for device memory allocation device and benchmark the execution performance. The rpc infrastructure enables data collection from a wide range of hardware backends for learning-based optimizations. - .. toctree:: :maxdepth: 1 runtime - .. toctree:: :maxdepth: 1 debugger - virtual_machine introduction_to_module_serialization device_target_interactions - +.. TODO(tvm-team) add a section about relax vm here tvm/node -------- @@ -275,11 +242,9 @@ Thanks to the node module, we can directly access any field of the TVM's IRNode # we can directly use the field name to access the IR structures assert y.a == x - We can also serialize arbitrary IR node into a JSON format, and load them back. The ability to save/store, and inspect an IR node provides a foundation for making the compiler more accessible. - tvm/ir ------ The `tvm/ir` folder contains the unified data structure and interfaces across for all IR function variants. @@ -331,11 +296,25 @@ in the target and builtin information registered to each target id(cuda, opencl) device_target_interactions +tvm/relax +--------- + +Relax is the high-level IR used to represent the computational graph of a model. Various optimizations are defined in ``relax.transform``. +Note that Relax usually works closely the the TensorIR IRModule, most of the transformations are applied on the both Relax and TensorIR functions +in the IRModule. Please refer to the :ref:`Relax Deep Dive ` for more details. + tvm/tir ------- TIR contains the definition of the low-level program representations. We use `tir::PrimFunc` to represent functions that can be transformed by TIR passes. -Besides the IR data structures, the tir module also defines a set of builtin intrinsics and their attributes via the common Op registry, as well as transformation passes in `tir/transform`. +Besides the IR data structures, the tir module also includes: + +- A set of schedule primitives to control the generated code in ``tir/schedule``. +- A set of builtin intrinsics in ``tir/tensor_intrin``. +- A set of analysis passes to analyze the TIR functions in ``tir/analysis``. +- A set of transformation passes to lower or optimize the TIR functions in ``tir/transform``. + +Please refer to the :ref:`TensorIR Deep Dive ` for more details. tvm/arith --------- @@ -344,75 +323,28 @@ This module is closely tied to the TIR. One of the key problems in the low-level arithmetic properties — the positiveness, variable bound, and the integer set that describes the iterator space. arith module provides a collection of tools that do (primarily integer) analysis. A TIR pass can use these analyses to simplify and optimize the code. -tvm/te ------- - -The name te stands for "tensor expression". This is a domain-specific language module that allows us to construct `tir::PrimFunc` variants quickly by writing tensor expressions. -Importantly, a tensor expression itself is not a self-contained function that can be stored into IRModule. Instead, it is a fragment of IR that we can stitch together to build an IRModule. +tvm/te and tvm/topi +------------------- -`te/schedule` provides a collection of scheduling primitives to control the function being generated. In the future, we might bring some of -these scheduling components to the a `tir::PrimFunc` itself. +TE stands for Tensor Expression. TE is a domain-specific language (DSL) for describing tensor computations. Importantly, a tensor expression +itself is not a self-contained function that can be stored into IRModule. We can use ``te.create_prim_func`` to convert a tensor expression to a ``tir::PrimFunc`` +and then integrate it into the IRModule. -.. toctree:: - :maxdepth: 1 - - inferbound - hybrid_script - -tvm/topi --------- While possible to construct operators directly via TIR or tensor expressions (TE) for each use case it is tedious to do so. -`topi` (Tensor operator inventory) provides a set of pre-defined operators (in TE or TIR) defined by -numpy and found in common deep learning workloads. We also provide a collection of common schedule templates to obtain performant implementations across different target platforms. - - -tvm/relay ---------- -Relay is the high-level functional IR used to represent full models. Various optimizations are defined in `relay.transform`. The Relay compiler defines multiple dialects, -and each dialect is designed to support specific styles of optimization. Notable ones include QNN(for importing pre-quantized models), VM(for lowering to dynamic virtual machine), -memory(for memory optimization). - -.. toctree:: - :maxdepth: 1 - - relay_intro - relay_op_strategy - convert_layout - - -tvm/autotvm ------------ +`topi` (Tensor operator inventory) provides a set of pre-defined operators defined by numpy and found in common deep learning workloads. -AutoTVM and AutoScheduler are both components which automate search based program optimization. This is rapidly evolving and primarily consists of: +tvm/meta_schedule +----------------- -- Cost models and feature extraction. -- A record format for storing program benchmark results for cost model construction. -- A set of search policies over program transformations. +MetaSchedule is a system for automated search-based program optimization. It is designed to be a drop-in replacement for AutoTVM and AutoScheduler, +and can be used to optimize TensorIR schedules. Note that MetaSchedule only works with static-shape workloads. -Automated program optimization is still an active research field. As a result, we have attempted to modularize the design so that researchers may quickly modify a -component or apply their own algorithms via the Python bindings, and -customize the search and plugin their algorithms from the Python binding. - -.. toctree:: - :maxdepth: 1 - - benchmark - -Frontends ---------- -Frontends ingest models from different frameworks into the TVM stack. -:py:mod:`tvm.relay.frontend` is the namespace for model ingestion APIs. - -.. toctree:: - :maxdepth: 1 - - frontend/tensorflow +tvm/dlight +---------- -microTVM --------- -.. toctree:: - :maxdepth: 1 +DLight is a set of pre-defined, easy-to-use, and performant TIR schedules. DLight aims: - microtvm_design - microtvm_project_api - model_library_format +- Fully support **dynamic shape workloads**. +- **Light weight**. DLight schedules provides tuning-free or (very few-shots tuning) schedule with reasonable performance. +- **Robust**. DLight schedules are designed to be robust and general-purpose for a single rule. And if the rule is not applicable, + DLight not raise any error and switch to the next rule automatically. diff --git a/docs/arch/inferbound.rst b/docs/arch/inferbound.rst deleted file mode 100644 index cc516359bdba..000000000000 --- a/docs/arch/inferbound.rst +++ /dev/null @@ -1,763 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. _dev-InferBound-Pass: - -******************************************* -InferBound Pass -******************************************* - - -The InferBound pass is run after normalize, and before ScheduleOps `build_module.py `_. The main job of InferBound is to create the bounds map, which specifies a Range for each IterVar in the program. These bounds are then passed to ScheduleOps, where they are used to set the extents of For loops, see `MakeLoopNest `_, and to set the sizes of allocated buffers (`BuildRealize `_), among other uses. - -The output of InferBound is a map from IterVar to Range: - -.. code:: cpp - - Map InferBound(const Schedule& sch); - -Therefore, let's review the Range and IterVar classes: - -.. code:: cpp - - namespace HalideIR { - namespace IR { - class RangeNode : public Node { - public: - Expr min; - Expr extent; - // remainder omitted - }; - }} - - namespace tvm { - class IterVarNode : public Node { - public: - Range dom; - Var var; - // remainder omitted - }; - } - -Note that IterVarNode also contains a Range ``dom``. This ``dom`` may or may not have a meaningful value, depending on when the IterVar was created. For example, when ``tvm.compute`` is called, an `IterVar is created `_ for each axis and reduce axis, with dom's equal to the shape supplied in the call to ``tvm.compute``. - -On the other hand, when ``tvm.split`` is called, `IterVars are created `_ for the inner and outer axes, but these IterVars are not given a meaningful ``dom`` value. - -In any case, the ``dom`` member of an IterVar is never modified during InferBound. However, keep in mind that the ``dom`` member of an IterVar is sometimes used as default value for the Ranges InferBound computes. - -We next review some TVM codebase concepts that are required to understand the InferBound pass. - -Recall that InferBound takes one argument, a Schedule. This schedule object, and its members, contains all information about the program being compiled. - -A TVM schedule is composed of Stages. Each stage has exactly one Operation, e.g., a ComputeOp or a TensorComputeOp. Each operation has a list of root_iter_vars, which in the case of ComputeOp, are composed of the axis IterVars and the reduce axis IterVars. Each operation can also contain many other IterVars, but all of them are related by the operations's list of IterVarRelations. Each IterVarRelation represents either a split, fuse or rebase in the schedule. For example, in the case of split, the IterVarRelation specifies the parent IterVar that was split, and the two children IterVars: inner and outer. - - -.. code:: cpp - - namespace tvm { - class ScheduleNode : public Node { - public: - Array outputs; - Array stages; - Map stage_map; - // remainder omitted - }; - - class StageNode : public Node { - public: - Operation op; - Operation origin_op; - Array all_iter_vars; - Array leaf_iter_vars; - Array relations; - // remainder omitted - }; - - class OperationNode : public Node { - public: - virtual Array root_iter_vars(); - virtual Array InputTensors(); - // remainder omitted - }; - - class ComputeOpNode : public OperationNode { - public: - Array axis; - Array reduce_axis; - Array body; - Array root_iter_vars(); - // remainder omitted - }; - } - -Tensors haven't been mentioned yet, but in the context of TVM, a Tensor represents output of an operation. - -.. code:: cpp - - class TensorNode : public Node { - public: - // The source operation, can be None - // This Tensor is output by this op - Operation op; - // The output index from the source operation - int value_index; - }; - -In the Operation class declaration above, we can see that each operation also has a list of InputTensors. Thus the stages of the schedule form a DAG, where each stage is a node in the graph. There is an edge in the graph from Stage A to Stage B, if the operation of Stage B has an input tensor whose source operation is the op of Stage A. Put simply, there is an edge from A to B, if B consumes a tensor produced by A. See the diagram below. This graph is created at the beginning of InferBound, by a call to `CreateReadGraph `_. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/stage_graph.png - :align: center - -InferBound makes one pass through the graph, visiting each stage exactly once. InferBound starts from the output stages (i.e., the solid blue nodes in the graph above), and moves upwards (in the opposite direction of the edges). This is achieved by performing a reverse topological sort on the nodes of the graph. Therefore, when InferBound visits a stage, each of its consumer stages has already been visited. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/inferbound_traversal.png - :align: center - -The InferBound pass is shown in the following pseudo-code: - -.. code:: cpp - - Map InferBound(const Schedule& sch) { - Array outputs = sch->get_outputs(); - G = CreateGraph(outputs); - stage_list = sch->reverse_topological_sort(G); - Map rmap; - for (Stage s in stage_list) { - InferRootBound(s, &rmap); - PassDownDomain(s, &rmap); - } - return rmap; - } - -The InferBound pass has two interesting properties that are not immediately obvious: - -1. After InferBound visits a stage, the ranges of all IterVars in the stage will be set in ``rmap``. -2. The Range of each IterVar is only set once in ``rmap``, and then never changed. - -So it remains to explain what InferBound does when it visits a stage. As can be seen in the pseudo-code above, InferBound calls two functions on each stage: InferRootBound, and PassDownDomain. The purpose of InferRootBound is to set the Range (in ``rmap``) of each root_iter_var of the stage. (Note: InferRootBound does not set the Range of any other IterVar, only those belonging to root_iter_vars). The purpose of PassDownDomain is to propagate this information to the rest of the stage's IterVars. When PassDownDomain returns, all IterVars of the stage have known Ranges in ``rmap``. - -The remainder of the document dives into the details of InferRootBound and PassDownDomain. Since PassDownDomain is simpler to describe, we will cover it first. - -.. _IterVarHyperGraph: - -IterVar Hyper-graph -------------------- - -The InferBound pass traverses the stage graph, as described above. However, within each stage is another graph, whose nodes are IterVars. InferRootBound and PassDownDomain perform message-passing on these IterVar graphs. - -Recall that all IterVars of the stage are related by IterVarRelations. The IterVarRelations of a stage form a directed acyclic hyper-graph, where each node of the graph corresponds to an IterVar, and each hyper-edge corresponds to an IterVarRelation. We can also represent this hyper-graph as a DAG, which is simpler to visualize as shown below. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/relations.png - :align: center - - -The above diagram shows the IterVar hyper-graph for one stage. The stage has one root_iter_var, ``i``. It has been split, and the resulting inner axis ``i.inner``, has been split again. The leaf_iter_vars of the stage are shown in green: ``i.outer``, ``i.inner.outer``, and ``i.inner.inner``. - -Message passing functions are named "PassUp" or "PassDown", depending on whether messages are passed from children to their parent in the DAG ("PassUp"), or from the parent to its children ("PassDown"). For example, the large arrow on the left-hand side of the diagram above, shows that PassDownDomain sends messages from the root IterVar ``i`` to its children ``i.outer`` and ``i.inner``. - -.. _PassDownDomain: - -PassDownDomain --------------- -The purpose of PassDownDomain is to take the Ranges produced by InferRootBound for the root_iter_vars, and set the Ranges of all other IterVars in the stage. - -PassDownDomain iterates through the stage's IterVarRelations. There are three possible types of IterVarRelation: split, fuse, and rebase. The most interesting case (since it offers opportunity for improvement), is IterVarRelations representing splits. - -The Ranges of the inner and outer IterVars of the split are set based on the parent IterVar's known Range, as follows: - -.. code:: cpp - - rmap[split->inner] = Range::FromMinExtent(0, split->factor) - rmap[split->outer] = Range::FromMinExtent(0, DivCeil(rmap[split->parent]->extent, split->factor)) - -There is an opportunity here to tighten the bounds produced by InferBound, when ``split->factor`` does not evenly divide the parent's extent. Suppose the parent's extent is 20, and the split factor is 16. Then on the second iteration of the outer loop, the inner loop only needs to perform 4 iterations, not 16. If PassDownDomain could set the extent of ``split->inner`` to ``min(split->factor, rmap[split->parent]->extent - (split->outer * split->factor))``, then the extent of the inner variable would properly adapt, based on which iteration of the outer loop is being executed. - -For Fuse relations, the Range of the fused IterVar is set based on the known Ranges of the inner and outer IterVars, as follows: - -.. code:: cpp - - rmap[fuse->fused] = Range::FromMinExtent(0, rmap[fuse->outer]->extent * rmap[fuse->inner]->extent) - - -InferRootBound --------------- - -Recall that InferBound calls InferRootBound, followed by :ref:`PassDownDomain` on each stage in the stage graph. The purpose of InferRootBound is to set the Range of each root_iter_var of the Stage's operation. These Ranges will be propagated to the rest of the stage's IterVars using :ref:`PassDownDomain`. Note that InferRootBound does not set the Range of any other IterVar, only those belonging to the stage's root_iter_vars. - -If the stage is an output stage or placeholder, InferRootBound simply sets the root_iter_var Ranges to their default values. The default Range for a root_iter_var is taken from the ``dom`` member of the IterVar (see the IterVarNode class declaration above). - -Otherwise, InferRootBound iterates through the consumers of the stage. IntSets are created for each of the consumer's IterVars, as follows. Phase 1) IntSets are initialized for the consumer's leaf_iter_vars, and propagated to the consumer's root_iter_vars by PassUpDomain (Phase 2). These IntSets are used to create TensorDom of the input tensors of the consumer stage (Phase 3). Finally, once all of the consumers have been processed, InferRootBound calls GatherBound, to set the Ranges of the stage's root_iter_vars, based on the TensorDoms (Phase 4). - -This process can seem complicated. One reason is that a stage can have more than one consumer. Each consumer has different requirements, and these must somehow be consolidated. Similarly, the stage may output more than one tensor, and each consumer only uses a particular subset of these tensors. Furthermore, even if a consumer uses a particular tensor, it may not use all elements of the tensor. - -As mentioned above, a consumer may only require a small number of elements from each tensor. The consumers can be thought of as making requests to the stage, for certain regions of its output tensors. The job of Phases 1-3 is to establish the regions of each output tensor that are required by each consumer. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/inferbound_phases.png - :align: center - -IntSets -~~~~~~~ - -During InferRootBound, Ranges are converted to IntSets, and message passing is performed over IntSets. Therefore, it is important to understand the difference between Ranges and IntSets. The name "IntSet" suggests it can represent an arbitrary set of integers, e.g., A = \{-10, 0, 10, 12, 13\}. This would certainly be more expressive than a Range, which only represents a set of contiguous integers, e.g., B = \{10,11,12\}. - -However, currently IntSets come in only three varieties: IntervalSets, StrideSets, and ModularSets. IntervalSets, similarly to Ranges, only represent sets of contiguous integers. A StrideSet is defined by a base IntervalSet, a list of strides, and a list of extents. However, StrideSet is unused, and ModularSet is only used by the frontend. - -Therefore, not all sets of integers can be represented by an IntSet in TVM currently. For example, set A in the example above can not be represented by an IntSet. However, in future the functionality of IntSet can be extended to handle more general kinds of integer sets, without requiring modification to users of IntSet. - -*InferBound is more complicated for schedules that contain compute_at. Therefore, we first explain InferBound for schedules that do not contain compute_at.* - -.. _Phase1: - -Phase 1: Initialize IntSets for consumer's leaf_iter_vars -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: cpp - - /* - * Input: Map rmap: contains the Range for each IterVar of the consumer stage - * Output: Map up_state: contains an IntSet for each leaf_iter_var of the consumer - */ - -In Phase 1, IntSets for each of the consumer's leaf_iter_vars are created, based on the Ranges of the leaf_iter_vars from ``rmap``. Recall that the consumer has already been visited by InferBound, so all of its IterVars have known Ranges in ``rmap``. - -There are three cases: - -- Case 1: Extent of leaf var's Range is 1. In this case, the up_state for the leaf is just a single point, equal to the Range's min. -- Case 2: *No relaxation is needed. In this case, the up_state for the leaf is just a single point, defined by the leaf var itself.* -- Case 3: Relaxation is needed. In this case, the leaf's Range is simply converted to an IntSet. - -For simplicity, we assume the schedule does not contain thread axes. In this case, Case 2 is only relevant if the schedule contains compute_at. Please refer to the section :ref:`InferBoundCA`, for further explanation. - -.. _Phase2: - -Phase 2: Propagate IntSets from consumer's leaves to consumer's roots -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: cpp - - /* - * Input: Map up_state: consumer leaf -> IntSet - * Output: Map dom_map: consumer root -> IntSet - */ - -The purpose of Phase 2 is to propagate the IntSet information from the consumer's leaf_iter_vars to the consumer's root_iter_vars. The result of Phase 2 is another map, ``dom_map``, that contains an IntSet for each of the consumer's root_iter_vars. - -Phase 2 begins by calling PassUpDomain, which visits the IterVarRelations of the consumer stage. In the case of a Split relation, PassUpDomain sets the up_state of the parent IterVar, based on the inner and outer IntSets, as follows: - -- Case 1: The Ranges of outer and inner IterVars match their ``up_state`` domains. In this case, set the parent's ``up_state`` by simply converting the parent's Range to an IntSet. -- Case 2: *Otherwise, the parent's* ``up_state`` *is defined by evaluating* ``outer*f + inner + rmap[parent]->min``, *with respect to the* ``up_state`` *of outer and inner. Here, instead of using the Split relation's factor, TVM uses* ``f = rmap[inner]->extent``. - -Case 2 is only needed if the schedule contains compute_at. Please refer to the section :ref:`InferBoundCA` below, for further explanation. - -After PassUpDomain has finished propagating up_state to all IterVars of the consumer, a fresh map, from root_iter_vars to IntSet, is created. If the schedule does not contain compute_at, the IntSet for root_iter_var ``iv`` is created by the following code: - -.. code:: cpp - - dom_map[iv->var.get()] = IntSet::range(up_state.at(iv).cover_range(iv->dom)); - -Note that if the schedule does not contain compute_at, Phases 1-2 are actually unnecessary. dom_map can be built directly from the known Ranges in rmap. Ranges simply need to be converted to IntSets, which involves no loss of information. - -.. _Phase3: - -Phase 3: Propagate IntSets to consumer's input tensors -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: cpp - - /* - * Input: Map dom_map: consumer root -> IntSet - * Output: Map tmap: output tensor -> vector> - */ - -Note that the consumer's input tensors are output tensors of the stage InferBound is working on. So by establishing information about the consumer's input tensors, we actually obtain information about the stage's output tensors too: the consumers require certain regions of these tensors to be computed. This information can then be propagated through the rest of the stage, eventually obtaining Ranges for the stage's root_iter_vars by the end of Phase 4. - -The output of Phase 3 is tmap, which is a map containing all of the stage's output tensors. Recall that a Tensor is multi-dimensional, with a number of different axes. For each output tensor, and each of that tensor's axes, tmap contains a list of IntSets. Each IntSet in the list is a request from a different consumer. - -Phase 3 is accomplished by calling PropBoundToInputs on the consumer. PropBoundToInputs adds IntSets to tmap's lists, for all input Tensors of the consumer. - -The exact behavior of PropBoundToInputs depends on the type of the consumer's operation: ComputeOp, TensorComputeOp, PlaceholderOp, ExternOp, etc. Consider the case of TensorComputeOp. A TensorComputeOp already has a Region for each of its Tensor inputs, defining the slice of the tensor that the operation depends on. For each input tensor i, and dimension j, a request is added to tmap, based on the corresponding dimension in the Region: - -.. code:: cpp - - for (size_t j = 0; j < t.ndim(); ++j) { - // i selects the Tensor t - tmap[i][j].push_back(EvalSet(region[j], dom_map)); - } - -.. _Phase4: - -Phase 4: Consolidate across all consumers -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: cpp - - /* - * Input: Map tmap: output tensor -> vector> - * Output: Map rmap: rmap is populated for all of the stage's root_iter_vars - */ - -Phase 4 is performed by GatherBound, whose behavior depends on the type of operation of the stage. We discuss the ComputeOp case only, but TensorComputeOp is the same. - -A ComputeOp has only a single output Tensor, whose axes correspond to the axis variables of the ComputeOp. The root_iter_vars of a ComputeOp include these axis variables, as well as the reduce_axis variables. If the root IterVar is an axis var, it corresponds to one of the axes of the output Tensor. GatherBound sets the Range of such a root IterVar to the union of all IntSets (i.e., union of all consumer requests) for the corresponding axis of the tensor. If the root IterVar is a reduce_axis, its Range is just set to its default (i.e., the ``dom`` member of IterVarNode). - -.. code:: cpp - - // 'output' selects the output tensor - // i is the dimension - rmap[axis[i]] = arith::Union(tmap[output][i]).cover_range(axis[i]->dom); - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/gatherbound.png - :align: center - - -The union of IntSets is computed by converting each IntSet to an Interval, and then taking the minimum of all minimums, and the maximum of all of these interval's maximums. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/union.png - :align: center - - -This clearly results in some unnecessary computation, i.e., tensor elements will be computed that are never used. - -Unfortunately, even if we're lucky and the IntervalSet unions do not produce unnecessary computation, the fact that GatherBound considers each dimension of the tensor separately can also cause unnecessary computation. For example, in the diagram below the two consumers A and B require disjoint regions of the 2D tensor: consumer A requires T[0:2, 0:2], and consumer B requires T[2:4, 2:4]. GatherBound operates on each dimension of the tensor separately. For the first dimension of the tensor, GatherBound takes the union of intervals 0:2 and 2:4, producing 0:4 (note that no approximation was required here). Similarly for the second dimension of the tensor. Therefore, the dimension-wise union of these two requests is T[0:4, 0:4]. So GatherBound will cause all 16 elements of tensor T to be computed, even though only half of those elements will ever be used. - - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/gatherbound_problem.png - :align: center - -.. _InferBoundCA: - -InferBound with compute_at --------------------------- - -If the schedule contains compute_at, Phases 1-2 of InferRootBound become more complex. - -Motivation -~~~~~~~~~~ - -**Ex. 1** - -Consider the following snippet of a TVM program: - -:: - - C = tvm.compute((5, 16), lambda i, j : tvm.const(5, "int32"), name='C') - D = tvm.compute((5, 16), lambda i, j : C[i, j]*2, name='D') - -This produces the following (simplified IR): - -:: - - for i 0, 5 - for j 0, 16 - C[i, j] = 5 - for i 0, 5 - for j 0, 16 - D[i, j] = C[i, j]*2 - -It's easy to see that stage D requires all (5,16) elements of C to be computed. - -**Ex. 2** - -However, suppose C is computed at axis j of D: - -:: - - s = tvm.create_schedule(D.op) - s[C].compute_at(s[D], D.op.axis[1]) - -Then only a single element of C is needed at a time: - -:: - - for i 0, 5 - for j 0, 16 - C[0] = 5 - D[i, j] = C[0]*2 - -**Ex. 3** - -Similarly, if C is computed at axis i of D, only a vector of 16 elements of C are needed at a time: - -:: - - for i 0, 5 - for j 0, 16 - C[j] = 5 - for j 0, 16 - D[i, j] = C[j]*2 - -Based on the above examples, it is clear that InferBound should give different answers for stage C depending on where in its consumer D it is "attached". - -.. _AttachPaths: - -Attach Paths -~~~~~~~~~~~~ - -If stage C is computed at axis j of stage D, we say that C is *attached* to axis j of stage D. This is reflected in the Stage object by setting the following three member variables: - -.. code:: cpp - - class StageNode : public Node { - public: - // omitted - - // For compute_at, attach_type = kScope - AttachType attach_type; - - // For compute_at, this is the axis - // passed to compute_at, e.g., D.op.axis[1] - IterVar attach_ivar; - - // The stage passed to compute_at, e.g., D - Stage attach_stage; - - // omitted - }; - -Consider the above examples again. In order for InferBound to determine how many elements of C must be computed, it is important to know whether the computation of C occurs within the scope of a leaf variable of D, or above that scope. For example, in Ex. 1, the computation of C occurs *above* the scopes of all of D's leaf variables. In Ex. 2, the computation of C occurs *within* the scope of all of D's leaf variables. In Ex. 3, C occurs within the scope of D's i, but above the scope of D's j. - -CreateAttachPath is responsible for figuring out which scopes contain a stage C. These scopes are ordered from innermost scope to outermost. Thus for each stage CreateAttachPath produces an "attach path", which lists the scopes containing the stage, from innermost to outermost scope. In Ex. 1, the attach path of C is empty. In Ex. 2, the attach path of C contains {j, i}. In Ex. 3, the attach path of C is {i}. - -The following example clarifies the concept of an attach path, for a more complicated case. - -**Ex. 4** - -:: - - C = tvm.compute((5, 16), lambda i, j : tvm.const(5, "int32"), name='C') - D = tvm.compute((4, 5, 16), lambda di, dj, dk : C[dj, dk]*2, name='D') - s = tvm.create_schedule(D.op) - s[C].compute_at(s[D], D.op.axis[2]) - -Here is the IR after ScheduleOps (note that loops with extent 1 have been preserved, using the ``debug_keep_trivial_loop`` argument of ScheduleOps): - -:: - - realize D([0, 4], [0, 5], [0, 16]) { - produce D { - for (di, 0, 4) { - for (dj, 0, 5) { - for (dk, 0, 16) { - realize C([dj, 1], [dk, 1]) { - produce C { - for (i, 0, 1) { - for (j, 0, 1) { - C((i + dj), (j + dk)) =5 - } - } - } - D(di, dj, dk) =(C(dj, dk)*2) - } - } - } - } - } - } - -In this case, the attach path of C is {dk, dj, di}. Note that C does not use di, but di still appears in C's attach path. - -**Ex. 5** - -Compute_at is commonly applied after splitting, but this can be handled very naturally given the above definitions. In the example below, the attachment point of C is j_inner of D. The attach path of C is {j_inner, j_outer, i}. - -:: - - C = tvm.compute((5, 16), lambda i, j : tvm.const(5, "int32"), name='C') - D = tvm.compute((5, 16), lambda i, j : C[i, j]*2, name='D') - s = tvm.create_schedule(D.op) - d_o, d_i = s[D].split(D.op.axis[1], factor=8) - s[C].compute_at(s[D], d_i) - -The IR in this case looks like: - -:: - - for i 0, 5 - for j_outer 0, 2 - for j_inner 0, 8 - C[0] = 5 - D[i, j_outer*8 + j_inner] = C[0]*2 - -Building an Attach Path -~~~~~~~~~~~~~~~~~~~~~~~ - -We continue to refer to stages C and D, as introduced in the previous section. The CreateAttachPath algorithm builds the attach path of a stage C as follows. If C does not have attach_type ``kScope``, then C has no attachment, and C's attach path is empty. Otherwise, C is attached at attach_stage=D. We iterate through D's leaf variables in top-down order. All leaf variables starting from C.attach_ivar and lower are added to C's attach path. Then, if D is also attached somewhere, e.g., to stage E, the process is repeated for E's leaves. Thus CreateAttachPath continues to add variables to C's attach path until a stage with no attachment is encountered. - -In the example below, C is attached at D, and D is attached at E. - -:: - - C = tvm.compute((5, 16), lambda ci, cj : tvm.const(5, "int32"), name='C') - D = tvm.compute((5, 16), lambda di, dj : C[di, dj]*2, name='D') - E = tvm.compute((5, 16), lambda ei, ej : D[ei, ej]*4, name='E') - s = tvm.create_schedule(E.op) - s[C].compute_at(s[D], D.op.axis[1]) - s[D].compute_at(s[E], E.op.axis[1]) - -With ``debug_keep_trivial_loop=True``, the attach path of C is {dj, di, ej, ei}, and the attach path of D is {ej, ei}: - -:: - - // attr [D] storage_scope = "global" - allocate D[int32 * 1] - // attr [C] storage_scope = "global" - allocate C[int32 * 1] - produce E { - for (ei, 0, 5) { - for (ej, 0, 16) { - produce D { - for (di, 0, 1) { - for (dj, 0, 1) { - produce C { - for (ci, 0, 1) { - for (cj, 0, 1) { - C[(ci + cj)] = 5 - } - } - } - D[(di + dj)] = (C[(di + dj)]*2) - } - } - } - E[((ei*16) + ej)] = (D[0]*4) - } - } - } - -InferBound with compute_at -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Now that the concept of an attach path has been introduced, we return to how InferBound differs if the schedule contains compute_at. The only difference is in InferRootBound, :ref:`Phase1` and :ref:`Phase2`. - -In InferRootBound, the goal is to determine Ranges for the root_iter_vars of a particular stage, C. Phases 1-2 of InferRootBound assign IntSets to the leaf IterVars of C's consumers, and then propagate those IntSets up to the consumers' root_iter_vars. - -If there are no attachments, the Ranges already computed for the consumer's variables define how much of C is needed by the consumer. However, if the stage is actually inside the scope of one of the consumer's variables j, then only a single point within the Range of j is needed at a time. - -.. _Phase1CA: - -Phase 1: Initialize IntSets for consumer's leaf_iter_vars -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: cpp - - /* - * Input: Map rmap: contains the Range for each IterVar of the consumer stage - * Output: Map up_state: contains an IntSet for each leaf_iter_var of the consumer - */ - -In Phase 1, IntSets for each of the consumer's leaf_iter_vars are created, based on the Ranges of the leaf_iter_vars from rmap. Recall that the consumer has already been visited by InferBound, so all of its IterVars have known Ranges in rmap. - -There are three cases: - -- Case 1: Extent of leaf var's Range is 1. In this case, the up_state for the leaf is just a single point, equal to the Range's min. -- Case 2: No relaxation is needed. In this case, the up_state for the leaf is just a single point, defined by the leaf var itself. -- Case 3: Relaxation is needed. In this case, the leaf's Range is simply converted to an IntSet. - -Case 2 occurs if we encounter the attachment point of stage C in the consumer. For this attach_ivar, and all higher leaf variables of the consumer, Case 2 will be applied. This ensures that only a single point within the Range of the leaf variable will be requested, if C is inside the leaf variable's scope. - -.. _Phase2CA: - -Phase 2: Propagate IntSets from consumer's leaves to consumer's roots -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: cpp - - /* - * Input: Map up_state: consumer leaf -> IntSet - * Output: Map dom_map: consumer root -> IntSet - */ - -Phase 2 begins by calling PassUpDomain, which visits the IterVarRelations of the consumer stage. In the case of a Split relation, PassUpDomain sets the up_state of the parent IterVar, based on the inner and outer IntSets, as follows: - -- Case 1: The Ranges of outer and inner IterVars match their ``up_state`` domains. In this case, set the parent's ``up_state`` by simply converting the parent's Range to an IntSet. -- Case 2: Otherwise, the parent's ``up_state`` is defined by evaluating ``outer*f + inner + rmap[parent]->min``, with respect to the ``up_state`` of outer and inner. Here, instead of using the Split relation's factor, TVM uses* ``f = rmap[inner]->extent``. - - -Now, because the schedule contains compute_at, it is possible for Case 2 to apply. This is because the leaf IntSets may now be initialized to a single point within their Range (Case 2 of :ref:`Phase1CA`), so the IntSets will no longer always match the Ranges. - -After PassUpDomain has finished propagating up_state to all IterVars of the consumer, a fresh map, from root_iter_vars to IntSet, is created. If the stage is not attached to the current consumer, then for each variable iv in the consumer's attach_path, iv's Range is added to a ``relax_set``. The root variables of the stage are evaluated with respect to this ``relax_set``. - -This is to handle cases like the following example, where C is not attached anywhere, but its consumer D is attached in stage E. In this case, D's attach_path, {ej, ei} must be considered when determining how much of C must be computed. - -:: - - C = tvm.compute((5, 16), lambda ci, cj : tvm.const(5, "int32"), name='C') - D = tvm.compute((5, 16), lambda di, dj : C[di, dj]*2, name='D') - E = tvm.compute((5, 16), lambda ei, ej : D[ei, ej]*4, name='E') - s = tvm.create_schedule(E.op) - s[D].compute_at(s[E], E.op.axis[1]) - - -:: - - for ci 0, 5 - for cj 0, 16 - C[ci, cj] = 5 - for ei 0, 5 - for ej 0, 16 - D[0] = C[ei, ej]*2 - E[ei, ej] = D[0]*4 - -Limitations of PassUpDomain -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This section describes known limitations of PassUpDomain. These limitations affect the Ranges produced by InferBound, as well as other users of PassUpDomain such as ``tensorize``. - -**Ex. 6** - -Above, we discussed the behavior of PassUpDomain on Split relations only. In the following example, the schedule contains ``fuse`` in addition to ``split``. In the TVM program below, the operation C has two axes that are fused, and then the fused axis is split. Note that all tensors are originally of shape ``(4, 4)`` and the fused axis is split by factor ``4`` as well. Therefore, it would be natural to assume that the effect of the fuse is simply undone by the split. However, this is not the case in TVM, as explained below. - -:: - - import tvm - from tvm import te - - n = 4 - m = 4 - - A = te.placeholder((n, m), name='A') - B = te.compute((n, m), lambda bi, bj: A[bi, bj]+2, name='B') - C = te.compute((n, m), lambda ci, cj: B[ci, cj]*3, name='C') - - s = te.create_schedule(C.op) - - fused_axes = s[C].fuse(C.op.axis[0], C.op.axis[1]) - xo, xi = s[C].split(fused_axes, 4) - - s[B].compute_at(s[C], xo) - - print(tvm.lower(s, [A, C], simple_mode=True)) - -The output of this program is shown below. Notice that all 16 elements of B are computed every time through the outer loop, even though C only uses 4 of them. - -:: - - // attr [B] storage_scope = "global" - allocate B[float32 * 16] - produce C { - for (ci.cj.fused.outer, 0, 4) { - produce B { - for (bi, 0, 4) { - for (bj, 0, 4) { - B[((bi*4) + bj)] = (A[((bi*4) + bj)] + 2.000000f) - } - } - } - for (ci.cj.fused.inner, 0, 4) { - C[((ci.cj.fused.outer*4) + ci.cj.fused.inner)] = (B[((ci.cj.fused.outer*4) + ci.cj.fused.inner)]*3.000000f) - } - } - } - -This is in contrast to the following IR, which is produced by modifying the above program by deleting the fuse and split, and replacing the compute_at with ``s[B].compute_at(s[C], C.op.axis[0])``. Note that in the IR below, only 4 elements of B are computed at a time, as desired. The size of buffer B is also smaller. - -:: - - // attr [B] storage_scope = "global" - allocate B[float32 * 4] - produce C { - for (ci, 0, 4) { - produce B { - for (bj, 0, 4) { - B[bj] = (A[((ci*4) + bj)] + 2.000000f) - } - } - for (cj, 0, 4) { - C[((ci*4) + cj)] = (B[cj]*3.000000f) - } - } - } - -This example demonstrates that contrary to what we expect, the split does not simply undo the fuse. So what causes the difference? Why is the entire tensor B re-computed 4 times, when only a single row is actually needed at a time? - -Determining the amount of B that must be computed is the responsibility of InferBound. However, the Ranges returned by InferBound for B's root_iter_vars are too large in this case: ``[0, 4]`` for both ``bi`` and ``bj``. This occurs because of a limitation in PassUpDomain on Fuse relations, which we explain next. - -When InferRootBound is working on stage B, it visits B's consumer stage C to find out how much of B is requested by C. C has root_iter_vars ci and cj, which have been fused and then split. This results in the following :ref:`IterVarHyperGraph` for stage C. - - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/passupdomain_problem.png - :align: center - - - -We trace the execution of InferRootBound on stage B. Recall that :ref:`Phase1CA` of InferRootBound involves setting the IntSets for all leaf_iter_vars of B's consumer stage C. In this case, C's leaf_iter_vars are ``ci.cj.fused.outer`` and ``ci.cj.fused.inner``. Since B is attached at ``ci.cj.fused.outer``, ``ci.cj.fused.inner`` must be relaxed but ``ci.cj.fused.outer`` is a single point. The IntSets of C's leaf_iter_vars, after :ref:`Phase1CA`, are shown in the following table. - -+----------------------+---------------------------------------------------+ -| IterVar | IntSet after Phase 1 | -+======================+===================================================+ -| ``ci.cj.fused.inner``|``[0, (min(4, (16 - (ci.cj.fused.outer*4))) - 1)]``| -+----------------------+---------------------------------------------------+ -| ``ci.cj.fused.outer``| ``[ci.cj.fused.outer, ci.cj.fused.outer]`` | -+----------------------+---------------------------------------------------+ - -In :ref:`Phase2CA` of InferRootBound, PassUpDomain is called on all of C's IterVarRelations in bottom-up order. - -PassUpDomain is called on C's Split node first. Case 2 of PassUpDomain applies, because the IntSet of ``ci.cj.fused.outer`` is just a single point, and doesn't equal its Range (as previously computed by InferBound on stage C). PassUpDomain therefore sets the IntSet of ``ci.cj.fused`` based on the IntSets of ``ci.cj.fused.inner`` and ``ci.cj.fused.outer``, as shown in row 3 of the following table. - -+----------------------+--------------------------------------------------------------------------------------------------+ -| IterVar | IntSet after PassUpDomain on SplitNode | -+======================+==================================================================================================+ -| ``ci.cj.fused.inner``| ``[0, (min(4, (16 - (ci.cj.fused.outer*4))) - 1)]`` | -+----------------------+--------------------------------------------------------------------------------------------------+ -| ``ci.cj.fused.outer``| ``[ci.cj.fused.outer, ci.cj.fused.outer]`` | -+----------------------+--------------------------------------------------------------------------------------------------+ -| ``ci.cj.fused`` | ``[(ci.cj.fused.outer*4), ((ci.cj.fused.outer*4) + (min(4, (16 - (ci.cj.fused.outer*4))) - 1))]``| -+----------------------+--------------------------------------------------------------------------------------------------+ - -After PassUpDomain is called on the Split node, it is called on the Fuse node. - -- Case 1: the Range of IterVar ``fused`` (i.e., as previously calculated by InferBound) is equal to its IntSet -- Case 2: the IntSet of IterVar ``fused`` is a single point -- Case 3: otherwise - -In our case, the Range of ``ci.cj.fused``, is [0, 16). This is not equal to the IntSet of ``ci.cj.fused``, which has extent at most 4 (see row 3 of the table above). Therefore Case 1 does not apply. Case 2 doesn't apply either, since the IntSet of ``ci.cj.fused`` is not a single point. Therefore, only the default Case 3 applies. - -Unfortunately in Case 3, PassUpDomain conservatively applies a "fallback inference rule", i.e., it just returns IntSets equal to the Ranges of ``ci`` and ``cj``. Since C is the output stage of the schedule, we know that InferBound will have set the Ranges of the root_iter_vars of C (i.e., ``ci`` and ``cj``) to their original dimensions (i.e., the ``dom`` value of their IterVars). The resulting output of PassUpDomain for ``ci`` and ``cj`` is shown in the last two rows of the table below. - -+----------------------+--------------------------------------------------------------------------------------------------+ -| IterVar | IntSet after PassUpDomain on FuseNode | -+======================+==================================================================================================+ -| ``ci.cj.fused.inner``| ``[0, (min(4, (16 - (ci.cj.fused.outer*4))) - 1)]`` | -+----------------------+--------------------------------------------------------------------------------------------------+ -| ``ci.cj.fused.outer``| ``[ci.cj.fused.outer, ci.cj.fused.outer]`` | -+----------------------+--------------------------------------------------------------------------------------------------+ -| ``ci.cj.fused`` |``[(ci.cj.fused.outer*4), ((ci.cj.fused.outer*4) + (min(4, (16 - (ci.cj.fused.outer*4))) - 1))]`` | -+----------------------+--------------------------------------------------------------------------------------------------+ -| ``ci`` | ``[0, 4]`` | -+----------------------+--------------------------------------------------------------------------------------------------+ -| ``cj`` | ``[0, 4]`` | -+----------------------+--------------------------------------------------------------------------------------------------+ - -This is enough to guarantee that consumer C requests *all* elements of B: the IntSets of ``ci`` and ``cj`` become requests from consumer C to the output tensors of stage B (via PropBoundToInputs in :ref:`Phase3` and GatherBound in :ref:`Phase4`). - -This example shows that schedules containing a split of fused axes are difficult to handle in TVM. The source of the difficulty is similar to the limitations of GatherBound. The region of tensor B requested by a consumer C must be a single rectangular region of B. Or, if B has more than two dimensions, the region of B must be expressible as an independent Range for each of its axes. - -If the split factor is 4, or 8, in the above example, the region of B needed in each iteration of the outer loop is rectangular. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/passupdomain_div.png - :align: center - -However, if the split factor is changed from 4 to 3 in the example above, it is easy to see that the region of B that C needs can no longer be described by an independent Range for each of its axes. - - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/passupdomain_nodiv.png - :align: center - -The best that can be done with rectangular regions is shown in the following diagram. The orange regions are the minimum rectangular regions covering the region of B that needs to be computed, at each iteration of the outer loop. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/docs/inferbound/passupdomain_min.png - :align: center diff --git a/docs/arch/microtvm_design.rst b/docs/arch/microtvm_design.rst deleted file mode 100644 index f9c06c10b677..000000000000 --- a/docs/arch/microtvm_design.rst +++ /dev/null @@ -1,357 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at -.. http://www.apache.org/licenses/LICENSE-2.0 -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. _microtvm-design: - -************************** -microTVM Design Document -************************** - -.. contents:: Table of Contents - :depth: 3 - -Background -=========== - -TVM is a model deployment framework that has demonstrated good performance across a wide range of -models on traditional operating systems. Given TVM's layered approach to compilation, it is a -natural extension to target bare metal devices. While most of the compilation flow does not need to -change for a proof-of-concept implementation on such devices, the runtime cannot depend on: - -* **Virtual Memory**, and by extension any system-provided ``malloc``. Additionally, bare metal - devices typically have very limited memory (measured in KB). Because of this, libraries designed - for such platforms typically need to be more judicious in using memory, and need to release - memory when it is not in use. -* Traditional OS abstractions, such as **files**, **libraries**, and **kernel functions**. Some - projects implement support for these, but they are by no means standard. -* Support for programming languages other than **C**. - -Such changes require a different approach from the TVM C++ runtime typically used on traditional -Operating Systems. - -Typical Use -=========== - -This section discusses our vision of the "typical" microTVM use case. Each component used to achieve -this typical use case is intended to be designed for flexibility, but this unifying vision serves to -motivate the inclusion of each part of the design. - -.. figure:: https://raw.githubusercontent.com/tvmai/web-data/main/images/dev/microtvm_workflow.svg - :align: center - :width: 85% - -The parts of this process are described below: - -#. **Model Import**. The user imports an existing model or describes a new model to TVM, producing a - *Relay module*. - -#. **Model Transformations**. The user can apply transformations, such as quantization, to the - model. After each transformation, the user should still have a Relay module. - -#. **Compilation** (Scheduling and Code Generation). TVM implements each operator into Tensor IR by - assigning a schedule and schedule configuration to each Relay operator. Then, code (C source or - compiled object) is generated for each operator. - -#. **Integration**. The generated code is integrated along with the TVM C Runtime library into a - user-supplied binary project. In some cases (such as when the project is standardized across - multiple SoC/development boards), this process is handled automatically. - -#. **Deployment**. The project is built and the residual firmware binary is flashed onto the device. - Model inference is driven either by TVM using an on-device RPC server, or on the device using the - on-device Graph Executor. - -Design Goals -============ - -microTVM aims to achieve these design goals: - -1. **Portable Code**. microTVM can translate any Relay model into C code that can compile with only - a C standard library. -2. **Minimal Overhead**. microTVM generates target-specific, highly optimized code. As much overhead - from the runtime should be removed. -3. **Accessible Code**. microTVM considers C source code as a first-class output mechanism so that - it is easier for a firmware engineer to understand and tweak. - -Overview -======== - -microTVM requires changes at all levels of the TVM compiler stack. The following sub-sections enumerate -these changes at a high level, and follow-on sections discuss the specifics in more detail. - -Modeling Target Platforms -------------------------- - -TVM's search-based optimization approach allows it to largely avoid system-level modeling of targets -in favor of experimental results. However, some modeling is necessary in order to ensure TVM is -comparing apples-to-apples search results, and to avoid wasting time during the search by attempting -to compile invalid code for a target. - -microTVM models these parts of the target: - -* The CPU used, through the ``-mcpu`` and ``-march`` target flags. -* The presence or absence of accelerators, through the device components of the target (Currently - only the absence of accelerators can be expressed, but this mechanism should extend well). - -microTVM aims to model these parts of the target in the future: - -* Memory, modeled as a set of disjoint memory spaces, each with a label and size and prefetch/flush - behavior. Some memory may be shared with accelerators. -* Target runtime configuration (i.e. clock tree configuration, clock speed, etc). This is intended - only to contribute to the AutoTVM schedule key and not for any other use. - -At this time, TVM does not intend to model: - -* Size, type, or relationship of caches, with the exception of prefetching or cache flushing. - - -TVM Targets for microTVM -------------------------- - -A central data structure in the compilation process is the ``tvm::target::Target`` class. TVM uses -Target to decide which TIR schedules to enable and how to configure the code generator. The Target -class should also uniquely identify the generated code for a particular operator, as autotuning -logs use it to rank measured performance (but see Future Work). - -Targets are currently represented as strings structured similarly to command-line arguments. An -example target is shown below: - - ``c -keys=arm_cpu -mcpu=cortex-m7 -model=stm32f746xx`` - -The relevant parts to microTVM are: - - * Code generator (``llvm`` or ``c``) - * ``-mcpu=cortex-m7``: used by TOPI to enable Cortex-M schedules, and, when the C source code - generator is selected, included in the output as a comment to help identify the code and - configure the downstream C compiler. - -Runtime and Executor configuration for microTVM ------------------------------------------------ - -When using microTVM, it's important to use the C Runtime (``Runtime('crt')``), which is the runtime that works best on micro devices rather than the more dynamic C++ Runtime. Alongside this, there are two executors which you could use in combination with the C runtime: - -* ``Executor("aot")`` - The Ahead of Time (AOT) executor precompiles the network into a runnable function which you can add directly into your micro application -* ``Executor("graph", {"link-params": True})`` - The Graph executor provides a JSON representation of your network and requires the C Runtime's system library to be generated to find functions in the function registry (``Runtime("crt", {"system-lib": True})``). ``{"link-params":True}`` enables parameters to be linked into the generated files rather than provided externally. - -These are specified when building a runtime module: ``relay.build(..., runtime=..., executor=...)``. - -Writing Schedules for microTVM ------------------------------- - -For operations scheduled on the CPU, microTVM initially plans to make use of specialized -instructions and extern (i.e. hand-optimized) functions to achieve good performance. In TVM, this -approach is generally accomplished through tensorization, in which TVM breaks a computation into -small pieces, and a TIR extern function accelerates each small piece. - -TVM currently accommodates both approaches using ``tir.call_extern``. First, a pragma is attached to -the schedule defining the extern function in portable C. - - ``sched[output].pragma(n, "import_c", "void call_asm(int32_t* a, int32_t* b) { /* ... */ }")`` - -Next, ``tensorize`` is used to split the computation. - - ``sched[output].tensorize(owi, gemm)`` - -There are a couple of caveats to this approach, all which could be resolved by linking generated -code against external libraries: - -* Inline assembly is compiler-specific. While Clang and GCC have standardized on one syntax, this - may not be portable to other compilers. SDKs solve this by conditionally including a header file - depending on the compiler being used. However, taking this approach means that the generated code - needs additional compiler flags (i.e. ``-Isystempath/to/header``). -* It may be helpful to reference helper functions from the generated code (e.g. to inline common - sequences of hand-optimized assembly). -* Finally, the extern function invoked may be wholly written in an external library. If those - functions can be wholly inlined, this caveat is the same as the previous. If not, then additional - C code needs to be compiled and linked against the operator. - -At present, microTVM presumes that all eligible schedules can be compiled. This means that the user- -supplied project (see next section) must include all libraries that are used by the generated code. -When not using autotuning, TVM randomly chooses a fallback schedule, so all libraries would need to -be supported. When using autotuning, TVM selects the best-performing schedule, so only that library -is needed. There isn't currently a way to force TVM to pick a particular schedule outside of -autotuning logs, but that would be a good addition. - -Finally, when using the ``llvm`` backend, the process is similar except that LLVM bitcode is included -in the generated code (with an ``import_llvm`` pragma). LLVM bitcode provides a portable way to call -inline assembly. However, it may be more complex to call external C functions, and helper functions -are of course not easy to use from LLVM bitcode. - -Executing Models ----------------- - -The TVM compiler traditionally outputs three pieces: - -1. Model operator implementations, as discussed above; -2. A model execution graph, encoded as JSON; and -3. Simplified parameters. - -To correctly execute the model, a Graph Executor needs to reconstruct the graph in memory, load the -parameters, and then invoke the operator implementations in the correct order. - -microTVM supports two ways to do this: - -1. **Host-Driven**. The Graph Executor can run on the host and carry out execution by issuing - commands to the device using an RPC link with a UART-like transport. -2. **Standalone**. A C Graph Executor is available to be compiled on-device, but it is not - particularly memory efficient. This way enables standalone execution without any attached host. - -Host-Driven is designed for experimenting with models on-device and, like AutoTVM, uses the RPC server to -drive computation on-device. Standalone is intended for deployment. - -Host-Driven Execution -^^^^^^^^^^^^^^^^^^^^^ - -In Host-Driven execution, the firmware binary is the following: - -1. Generated operator implementations from TVM. -2. The TVM C runtime. -3. SoC-specific initialization. -4. The TVM RPC server. -5. (optional) Simplified Parameters. - -This firmware image is flashed onto the device and a GraphExecutor instance is created on the host. -The GraphExecutor drives execution by sending RPC commands over a UART: - -.. figure:: https://raw.githubusercontent.com/tvmai/web-data/main/images/dev/microtvm_host_driven.svg - :align: center - :width: 85% - -Standalone Execution -^^^^^^^^^^^^^^^^^^^^ - -In Standalone execution, the GraphExecutor is instantiated on device: - -.. figure:: https://raw.githubusercontent.com/tvmai/web-data/main/images/dev/microtvm_standalone.svg - :align: center - :width: 85% - -microTVM Firmware ------------------- - -We can now discuss how microTVM firmware should behave. An important task common to both model -execution strategies is configuring the SoC to match the way it performs in production. microTVM -considers this task project- and SoC-dependent. Whether for AutoTVM, host-driven model inference, or -in standalone deployment, the user is expected to supply a project whose main() does the following: - -1. Configure the SoC to match deployment performance. -2. Initialize the TVM C Runtime. - -When configuring for host-driven inference or AutoTVM, the remaining tasks are well-defined: - -3. Initialize a transport (i.e. a UART) for use with the TVM RPC server. -4. Launch the TVM RPC Server. - -When configuring for standalone deployment, the firmware needs to: - -1. Instantiate the system library by calling the ``runtime.SystemLib`` PackedFunc. -2. Instantiate a GraphExecutor passing the system library module. -3. Configure parameters and inputs as needed. -4. Run the model. - -Parts of a microTVM Binary --------------------------- - -To summarize, a microTVM firwmare binary image must contain these parts: - -1. Operator implementations, produced by TVM. -2. The TVM C runtime library, supplied by TVM as a static library. -3. SoC Initialization, supplied by the user. - -For Host-driven model execution, firmware also needs: - -4. The TVM RPC Server library. - -For Standalone model execution, firmware also needs: - -4. The TVM C GraphExecutor library, supplied by TVM as a static library. -5. The remaining compiler outputs (Simplified Parameters and Graph JSON). - -The Automated Build Flow ------------------------- - -Once code generation is complete, ``tvm.relay.build`` returns a ``tvm.runtime.Module`` and the -user can save the generated C source or binary object to a ``.c`` or ``.o`` file. From this point, TVM -can theoretically step back and the user can compile and run the code separately. - -However, for AutoTVM, TVM needs some automated flow to handle the following tasks: - -1. Integrate operator implementations, the TVM C Runtime library, and the TVM RPC Server library into the - firmware project containing user-supplied SoC Initialization. -2. Build the resulting project. -3. Program the built firmware onto a (specific) attached device. -4. Identify the serial port or other transport to be used by TVM to drive remote execution. - -At present, TVM expects the user to supply an implementation of the ``tvm.micro.Compiler``, -``tvm.micro.Flasher``, and ``tvm.micro.Transport`` interfaces. TVM then: - -1. Builds each piece separately as a library. -2. Builds the libraries into a binary firmware image. -3. Programs the firmware image onto an attached device. -4. Opens a serial port to serve as the RPC server transport. - -This design was chosen to reduce build times for microTVM (the common libraries need to be built -only once per candidate operator implemmentation). In practice, these projects are extremely small -and compile relatively quickly. Compared with the added complexity of this tighter build integration -with TVM, the performance gains are likely not worth it. A future design will consolidate the build -tasks into a single step and narrow the interface to provide a better integration. - -Measuring operator performance ------------------------------- - -The TVM C runtime depends on user-supplied functions to measure time on-device. Users should implement -``TVMPlatformTimerStart`` and ``TVMPlatformTimerStop``. These functions should measure wall clock time, so there -are some pitfalls in implementing these functions: - -1. If the CPU could halt or sleep during a computation (i.e. if it is being done on an accelerator), - a cycle counter should likely not be used as these tend to stop counting while the CPU is asleep. -2. The granularity of these functions can be relaxed as needed to extend the range of the timer - device. However, if granularity is too coarse, a sub-optimal schedule may be used. -3. An error should be raised if the timer overflows. -4. The timer should not interrupt computation unless absolutely necessary. Doing so may affect the - accuracy of the results. -5. Calibrating the output against a wall clock is ideal, but it will likely be too cumbersome. A - future PR could enable some characterization of the platform timer by, e.g., measuring the internal - oscillator against a reference such as an external crystal. - -Future Work -=========== - -Ahead-of-Time Runtime ----------------------- - -A limitation of the Graph Executor is the amount of memory overhead required in parsing the JSON. -The current implementation contributes significantly to the dynamic memory usage of microTVM, -limiting its utility. An ahead-of-time runtime can avoid the need for any Graph JSON parsing and -improve inference speed by generating C code to call the generated operator implementations directly -rather than relying on a data-driven approach with the Graph Executor. - -Memory Planning ----------------- - -The current memory planner attempts to limit the number of ``TVMBackendDeviceAlloc()`` calls -issued for intermediate tensors only. Because scratchpads can vary widely, and because the planner -coalesces memory allocations within 16x of each other, this strategy typically results in high -peak memory usage. - -Heterogeneous Execution ------------------------ - -Newer Cortex-M SoCs can contain multiple CPUs and onboard ML accelerators. - - -Autotuning Target ------------------ - -As discussed previously, diff --git a/docs/arch/microtvm_project_api.rst b/docs/arch/microtvm_project_api.rst deleted file mode 100644 index 381b57876aaa..000000000000 --- a/docs/arch/microtvm_project_api.rst +++ /dev/null @@ -1,150 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. _microtvm_project_api: - -microTVM Project API -==================== - -About microTVM Project API --------------------------- - -The microTVM Project API allows TVM to automatically run models on -unconventional or embedded platforms. It allows platforms to define a standard -function to integrate TVM compiler output with boilerplate platform-specific -code, producing a runnable **Project**. Project API then further defines -functions to build that project, program compatible devices accessible from the -TVM machine, and communicate with the running code so that TVM can perform -host-driven inference and autotuning. - -There are many cases where it might be desirable simply to invoke microTVM as a -tool from your platform's build process. Indeed, for the average firmware -developer, this is likely to be all they need. However, there are a couple of -use cases when you may want to teach microTVM how to build firmware using your -platform's build tool: - -1. To enable AutoTVM and AutoScheduling on your platform. Defining a Project - API implementation allows TVM to tune models for peak performance on your - platform. -2. To enable engineers without firmware expertise to experiment with models on - your platform. Defining a Project API implementation allows these engineers - to leverage the standard TVM Python workflows to perform host-driven - inference on your platform. -3. Integration Testing. Defining a Project API implementation allows you to - create Continuous Integration Tests which verify model correctness and - performance on your platform. - -API Definition --------------- - -The full API is the ``abstractmethod`` defined on ``ProjectAPIHandler`` in -`python/tvm/micro/project_api/server.py `_. -Rather than duplicate the documentation here, we simply refer you to that class. - -How TVM uses Project API ------------------------- - -This section explains how the Project API should be used with TVM. Project API -is defined around the *Project* as the buildable unit of firmware. TVM expects -to be provided initially with a directory containing a *Template Project*, which -together with a :ref:`Model Library Format ` file can be -built into a runnable project. - -Inside the Template Directory is (typically) a Python script implementing the -API server. TVM launches this script in a subprocess and sends commands to the -server to perform each of the actions outlined above. - -The typical usage flow is as follows: - -1. Launch Project API server in Template Project. -2. Verify the API server is version-compatible with TVM, plus read properties - of the implementation, by sending ``server_info_query`` command. -3. Generate a new project by sending command ``generate_project`` to create a - new project. The arguments to this command is a Model Library Format and a - non-existent directory which should be populated with the generated - project. The Template Project API server should copy itself into the - newly-generated project. -4. Terminate the Template Project API server. -5. Launch Project API server in Generated Project. -6. Verify the API server is version-compatible with TVM, plus read properties - of the implementation, by sending ``server_info_query`` command. -7. Build and flash the projec by sending commands ``build`` and ``flash`` to the - API server. -8. Communicate with the target. Send command ``open_transport`` followed by - commands ``write_transport`` and ``read_transport`` to write and read from - e.g. a serial port attached to the target. Upon completion, - ``close_transport`` is sent. -9. Terminate Project API server. - -Disk Layout of the Project --------------------------- - -In the root directory of a project (template or generated), one of the following -two files must exist: - -- ``microtvm_api_server.py`` - the suggested approach. Place a - python3-compatible Python script in the root directory. TVM will execute this - script in its own process using the same interpreter used to execute TVM. -- ``microtvm_api_server.sh`` (on Windows, ``microtvm_api_server.bat``) - - alternate approach. When a different Python interpreter is necessary, or - when you want to implement the server in a different language, create this - executable file. TVM will launch this file in a separate process. - -Aside from these two files, no other restrictions are made on the layout. - -Communication between TVM and Project API Server ------------------------------------------------- - -TVM communicates with the Project API server using `JSON-RPC 2.0 -`_. TVM always launches API servers using -the following command-line: - -``microtvm_api_server.py --read-fd --write-fd `` - -Commands are sent from TVM to the server over the file descriptor given by -``--read-fd`` and replies are received by TVM from the server over the file -descriptor given by ``--write-fd``. - -Helpers for Implementing the API server in Python -------------------------------------------------- - -TVM provides helper utilities that make it easy to implement the server in Python. -To implement the server in Python, create ``microtvm_api_server.py`` and add -``from tvm.micro.project_api import server`` (or, copy this file into your template -project--there are no dependencies--and import it there instead). Next, subclass -``ProjectAPIHander``:: - - class Handler(server.ProjectAPIHandler): - def server_info_query(self, tvm_version): - # Implement server_info_query - - def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): - # Implement generate_project - - # ... - -Finally, invoke the helper ``main()``:: - - if __name__ == "__main__": - server.main(Handler()) - -Using Project API from ``tvmc`` -------------------------------- - -Each major Project API command is available through the ``tvmc micro`` -sub-command to make debugging interactions simple. Invoke ``tvmc micro --help`` -for more information. diff --git a/docs/arch/model_library_format.rst b/docs/arch/model_library_format.rst deleted file mode 100644 index 3ee6b9878f3f..000000000000 --- a/docs/arch/model_library_format.rst +++ /dev/null @@ -1,171 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. _model_library_format: - -Model Library Format -==================== - -About Model Library Format --------------------------- - -TVM traditionally exports generated libraries as Dynamic Shared Objects (e.g. DLLs (Windows) or .so -(linux)). Inferences can be performed using those libraries by loading them into an executable using -``libtvm_runtime.so``. This process is very dependent on services provided by traditional OS. - -For deployment to unconventional platforms (e.g. those lacking traditional OS), TVM provides another -output format, Model Library Format. Initially, the microTVM project is the primary use case for this -format. Should it become useful in other use cases (and in particular, should it become possible to -export BYOC artifacts in Model Library Format), it could be used as a general-purpose TVM export -format. Model Library Format is a tarball containing a file for each piece of the TVM compiler -output. - -What can be Exported? ---------------------- - -At the time of writing, export is limited to full models built with ``tvm.relay.build``. - -Directory Layout ----------------- - -Model Library Format is contained within a tarball. All paths are relative to the root of the -tarball: - -- ``/`` - Root of the tarball - - - ``codegen`` - Root directory for all generated device code - - - (see `codegen`_ section) - - - ``executor-config/`` - Configuration for the executor which drives model inference - - - ``graph/`` - Root directory containing configuration for the GraphExecutor - - - ``graph.json`` - GraphExecutor JSON configuration - - - ``metadata.json`` - Machine-parseable metadata for this model - - - ``parameters/`` - Root directory where simplified parameters are placed - - - ``.params`` - Parameters for the model tvm.relay._save_params format - - - ``src/`` - Root directory for all source code consumed by TVM - - - ``relay.txt`` - Relay source code for the generated model - -Description of Sub-directories ------------------------------- - -.. _subdir_codegen: - -``codegen`` -^^^^^^^^^^^ - -All TVM-generated code is placed in this directory. At the time of writing, there is 1 file per -Module in the generated Module tree, though this restriction may change in the future. Files in -this directory should have filenames of the form ``/(lib|src)/.``. - -These components are described below: - - * ```` - Identifies the TVM target on which the code should run. Currently, only ``host`` - is supported. - * ```` - A unique slug identifying this file. Currently ``lib``, with ``>`` an - auto-incrementing integer. - * ```` - Suffix identifying the filename format. Currently ``c`` or ``o``. - -An example directory tree for a CPU-only model is shown below: - -- ``codegen/`` - Codegen directory - - - ``host/`` - Generated code for ``target_host`` - - - ``lib/`` - Generated binary object files - - - ``lib0.o`` - LLVM module (if ``llvm`` target is used) - - ``lib1.o`` - LLVM CRT Metadata Module (if ``llvm`` target is used) - - - ``src/`` - Generated C source - - - ``lib0.c`` - C module (if ``c`` target is used) - - ``lib1.c`` - C CRT Metadata module (if ``c`` target is used) - -``executor-config`` -^^^^^^^^^^^^^^^^^^^ - -Contains machine-parsable configuration for executors which can drive model inference. Currently, -only the GraphExecutor produces configuration for this directory, in ``graph/graph.json``. This -file should be read in and the resulting string supplied to the ``GraphExecutor()`` constructor for -parsing. - -``parameters`` -^^^^^^^^^^^^^^ - -Contains machine-parseable parameters. A variety of formats may be provided, but at present, only -the format produced by ``tvm.relay._save_params`` is supplied. When building with -``tvm.relay.build``, the ``name`` parameter is considered to be the model name. A single file is -created in this directory ``.json``. - -``src`` -^^^^^^^ - -Contains source code parsed by TVM. Currently, just the Relay source code is created in -``src/relay.txt``. - -Metadata --------- - -Machine-parseable metadata is placed in a file ``metadata.json`` at the root of the tarball. -Metadata is a dictionary with these keys: - -- ``export_datetime``: Timestamp when this Model Library Format was generated, in - `strftime `_ - format ``"%Y-%M-%d %H:%M:%SZ",``. -- ``memory``: A summary of the memory usage of each generated function. Documented in - `Memory Usage Summary`_. -- ``model_name``: The name of this model (e.g. the ``name`` parameter supplied to - ``tvm.relay.build``). -- ``executors``: A list of executors supported by this model. Currently, this list is always - ``["graph"]``. -- ``target``: A dictionary mapping ``device_type`` (the underlying integer, as a string) to the - sub-target which describes that relay backend used for that ``device_type``. -- ``version``: A numeric version number that identifies the format used in this Model Library - Format. This number is incremented when the metadata structure or on-disk structure changes. - This document reflects version ``5``. - -Memory Usage Summary -^^^^^^^^^^^^^^^^^^^^ - -A dictionary with these sub-keys: - - - ``"main"``: ``list[MainFunctionWorkspaceUsage]``. A list summarizing memory usage for each - workspace used by the main function and all sub-functions invoked. - - ``"operator_functions"``: ``map[string, list[FunctionWorkspaceUsage]]``. Maps operator function - name to a list summarizing memory usage for each workpace used by the function. - -A ``MainFunctionWorkspaceUsage`` is a dict with these keys: - -- ``"device"``: ``int``. The ``device_type`` associated with this workspace. -- ``"workspace_size_bytes"``: ``int``. Number of bytes needed in this workspace by this function - and all sub-functions invoked. -- ``"constants_size_bytes"``: ``int``. Size of the constants used by the main function. -- ``"io_size_bytes"``: ``int``. Sum of the sizes of the buffers used from this workspace by this - function and sub-functions. - -A ``FunctionWorkspaceUsage`` is a dict with these keys: - -- ``"device"``: ``int``. The ``device_type`` associated with this workspace. -- ``"workspace_size_bytes"``: ``int``. Number of bytes needed in this workspace by this function. diff --git a/docs/arch/relay_intro.rst b/docs/arch/relay_intro.rst deleted file mode 100644 index 87f68fcbce2e..000000000000 --- a/docs/arch/relay_intro.rst +++ /dev/null @@ -1,206 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. _relay-dev-intro: - -Introduction to Relay IR -======================== -This article introduces Relay IR -- the second generation of NNVM. -We expect readers from two kinds of background -- those who have a programming language background and deep learning -framework developers who are familiar with the computational graph representation. - -We briefly summarize the design goal here, and will touch upon these points in the later part of the article. - -- Support traditional data flow-style programming and transformations. -- Support functional-style scoping, let-binding and making it a fully featured differentiable language. -- Being able to allow the user to mix the two programming styles. - -Build a Computational Graph with Relay --------------------------------------- -Traditional deep learning frameworks use computational graphs as their intermediate representation. -A computational graph (or dataflow graph), is a directed acyclic graph (DAG) that represents the computation. -Though dataflow graphs are limited in terms of the computations they are capable of expressing due to -lacking control flow, their simplicity makes it easier to implement automatic differentiation and -compile for heterogeneous execution environments (e.g., executing parts of the graph on specialized hardware). - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/relay/dataflow.png - :align: center - - -You can use Relay to build a computational (dataflow) graph. Specifically, the above code shows how to -construct a simple two-node graph. You can find that the syntax of the example is not that different from existing -computational graph IR like NNVMv1, with the only difference in terms of terminology: - -- Existing frameworks usually use graph and subgraph -- Relay uses function e.g. -- ``fn (%x)``, to indicate the graph - -Each dataflow node is a CallNode in Relay. The Relay Python DSL allows you to construct a dataflow graph quickly. -One thing we want to highlight in the above code -- is that we explicitly constructed an Add node with -both input point to ``%1``. When a deep learning framework evaluates the above program, it will compute -the nodes in topological order, and ``%1`` will only be computed once. -While this fact is very natural to deep learning framework builders, it is something that might -surprise a PL researcher in the first place. If we implement a simple visitor to print out the result and -treat the result as nested Call expression, it becomes ``log(%x) + log(%x)``. - -Such ambiguity is caused by different interpretations of program semantics when there is a shared node in the DAG. -In a normal functional programming IR, nested expressions are treated as expression trees, without considering the -fact that the ``%1`` is actually reused twice in ``%2``. - -The Relay IR is mindful of this difference. Usually, deep learning framework users build the computational -graph in this fashion, where a DAG node reuse often occurs. As a result, when we print out the Relay program in -the text format, we print one CallNode per line and assign a temporary id ``(%1, %2)`` to each CallNode so each common -node can be referenced in later parts of the program. - -Module: Support Multiple Functions (Graphs) -------------------------------------------- -So far we have introduced how can we build a dataflow graph as a function. One might naturally ask: Can we support multiple -functions and enable them to call each other? Relay allows grouping multiple functions together in a module; the code below -shows an example of a function calling another function. - -.. code:: - - def @muladd(%x, %y, %z) { - %1 = mul(%x, %y) - %2 = add(%1, %z) - %2 - } - def @myfunc(%x) { - %1 = @muladd(%x, 1, 2) - %2 = @muladd(%1, 2, 3) - %2 - } - -The Module can be viewed as a ``Map``. Here GlobalVar is just an id that is used to represent the functions -in the module. ``@muladd`` and ``@myfunc`` are GlobalVars in the above example. When a CallNode is used to call another function, -the corresponding GlobalVar is stored in the op field of the CallNode. It contains a level of indirection -- we need to look up -body of the called function from the module using the corresponding GlobalVar. In this particular case, we could also directly -store the reference to the Function as op in the CallNode. So, why do we need to introduce GlobalVar? The main reason is that -GlobalVar decouples the definition/declaration and enables recursion and delayed declaration of the function. - -.. code :: - - def @myfunc(%x) { - %1 = equal(%x, 1) - if (%1) { - %x - } else { - %2 = sub(%x, 1) - %3 = @myfunc(%2) - %4 = add(%3, %3) - %4 - } - } - -In the above example, ``@myfunc`` recursively calls itself. Using GlobalVar ``@myfunc`` to represent the function avoids -the cyclic dependency in the data structure. -At this point, we have introduced the basic concepts in Relay. Notably, Relay has the following improvements over NNVMv1: - -- Succinct text format that eases debugging of writing passes. -- First-class support for subgraphs-functions, in a joint module, this enables further chance of joint optimizations such as inlining and calling convention specification. -- Naive front-end language interop, for example, all the data structure can be visited in Python, which allows quick prototyping of optimizations in Python and mixing them with C++ code. - - -Let Binding and Scopes ----------------------- - -So far, we have introduced how to build a computational graph in the good old way used in deep learning frameworks. -This section will talk about a new important construct introduced by Relay -- let bindings. - -Let binding is used in every high-level programming language. In Relay, it is a data structure with three -fields ``Let(var, value, body)``. When we evaluate a let expression, we first evaluate the value part, assign -it to the var, then return the evaluated result in the body expression. - -You can use a sequence of let bindings to construct a logically equivalent program to a dataflow program. -The code example below shows one program with two forms side by side. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/relay/dataflow_vs_func.png - :align: center - - -The nested let binding is called A-normal form, and it is commonly used as IRs in functional programming languages. -Now, please take a close look at the AST structure. While the two programs are semantically identical -(so are their textual representations, except that A-normal form has let prefix), their AST structures are different. - -Since program optimizations take these AST data structures and transform them, the two different structures will -affect the compiler code we are going to write. For example, if we want to detect a pattern ``add(log(x), y)``: - -- In the data-flow form, we can first access the add node, then directly look at its first argument to see if it is a log -- In the A-normal form, we cannot directly do the check anymore, because the first input to add is ``%v1`` -- we will need to keep a map from variable to its bound values and look up that map, in order to know that ``%v1`` is a log. - -Different data structures will impact how you might write transformations, and we need to keep that in mind. -So now, as a deep learning framework developer, you might ask, Why do we need let bindings? -Your PL friends will always tell you that let is important -- as PL is a quite established field, -there must be some wisdom behind that. - -Why We Might Need Let Binding ------------------------------ -One key usage of let binding is that it specifies the scope of computation. Let us take a look at the following example, -which does not use let bindings. - -.. image:: https://raw.githubusercontent.com/tvmai/tvmai.github.io/main/images/relay/let_scope.png - :align: center - -The problem comes when we try to decide where we should evaluate node ``%1``. In particular, while the text format seems -to suggest that we should evaluate node ``%1`` outside the if scope, the AST(as shown in the picture) does not suggest so. -Actually, a dataflow graph never defines its scope of the evaluation. This introduces some ambiguity in the semantics. - -This ambiguity becomes more interesting when we have closures. Consider the following program, which returns a closure. -We don’t know where should we compute ``%1``; it can be either inside or outside the closure. - -.. code:: - - fn (%x) { - %1 = log(%x) - %2 = fn(%y) { - add(%y, %1) - } - %2 - } - -A let binding solves this problem, as the computation of the value happens at the let node. In both programs, -if we change ``%1 = log(%x)`` to ``let %v1 = log(%x)``, we clearly specify the computation location to -be outside of the if scope and closure. As you can see let-binding gives a more precise specification of the computation site -and could be useful when we generate backend code (as such specification is in the IR). - -On the other hand, the dataflow form, which does not specify the scope of computation, does have its own advantages --- namely, we don’t need to worry about where to put the let when we generate the code. The dataflow form also gives more freedom -to the later passes to decide where to put the evaluation point. As a result, it might not be a bad idea to use data flow -form of the program in the initial phases of optimizations when you find it is convenient. -Many optimizations in Relay today are written to optimize dataflow programs. - -However, when we lower the IR to an actual runtime program, we need to be precise about the scope of computation. -In particular, we want to explicitly specify where the scope of computation should happen when we are using -sub-functions and closures. Let-binding can be used to solve this problem in later stage execution specific optimizations. - - -Implication on IR Transformations ---------------------------------- - -Hopefully, by now you are familiar with the two kinds of representations. -Most functional programming languages do their analysis in A-normal form, -where the analyzer does not need to be mindful that the expressions are DAGs. - -Relay choose to support both the dataflow form and let bindings. We believe that it is important to let the -framework developer choose the representation they are familiar with. -This does, however, have some implications on how we write passes: - -- If you come from a dataflow background and want to handle lets, keep a map of var to the expressions so you can perform lookup when encountering a var. This likely means a minimum change as we already need a map from expressions to transformed expressions anyway. Note that this will effectively remove all the lets in the program. -- If you come from a PL background and like A-normal form, we will provide a dataflow to A-normal form pass. -- For PL folks, when you are implementing something (like a dataflow-to-ANF transformation), be mindful that expressions can be DAGs, and this usually means that we should visit expressions with a ``Map`` and only compute the transformed result once, so the resulting expression keeps the common structure. - -There are additional advanced concepts such as symbolic shape inference, polymorphic functions -that are not covered by this material; you are more than welcome to look at other materials. diff --git a/docs/arch/relay_op_strategy.rst b/docs/arch/relay_op_strategy.rst deleted file mode 100644 index dbac7c821827..000000000000 --- a/docs/arch/relay_op_strategy.rst +++ /dev/null @@ -1,282 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. _relay-op-strategy: - -Relay Operator Strategy -======================= - -In order to lower Relay operators to the implementations defined in TOPI -library, a compute and schedule function need to be registered to each Relay -operator. However, compute and schedule functions are usually specialized for -each target, and further, even for the same target, we may have multiple -algorithms and implementations available. To deal with the complexity, we -introduce operator strategy to allow developers to define a flexible lowering -strategy for each operator and target. - - -Operator Strategy Design ------------------------- - -The basic element in operator strategy is an ``OpImplementation``. It includes -the a pair of compute and schedule function, the name of the implementation, -and a priority level (the use of priority level is explained in -`Select Implementation from Op Strategy`_). - -The ``OpStrategy`` includes a list of ``OpSpecialization``. Each ``OpSpecialization`` -contains a list of ``OpImplementation`` associated with a ``SpecializedCondition`` -(see definition in ``include/tvm/te/schedule.h``). The ``SpecializedCondition`` -can be null, indicating the implementations are generally applicable; -otherwise, the implementations are only considered when the specialized -condition is satisfied. ``SpecializedCondition`` consists of a list -of clauses defined in Tensor Expression in conjunctive normal form (CNF) and -only supports conditions on tensor shapes. - -Last, a strategy function, or ``FTVMStrategy``, determines which pair(s) of -compute and schedule functions should be used given a workload, and needs to be -registered to each Relay operator. ``FTVMStrategy`` is a generic function (see -``include/tvm/target/generic_func.h``), that can be overwritten for each -target. The function signature is - -.. code:: c - - OpStrategy(const Attrs& attrs, const Array& inputs, const Type& out_type, const Target& target) - -that the function returns an ``OpStrategy`` given the op attributes, input -tensors, output types, and target to compile to. - - -Write A Strategy Function -------------------------- - -We recommend developers to write strategy function in Python as -most TOPI compute and schedule functions are written in Python. -In python, we provide ``OpStrategy`` class in ``pyton/tvm/relay/op/op.py``. -It only has one API, which is to add an implementation to the strategy: - -.. code:: python - - def add_implementation(self, compute, schedule, name="default", plevel=10) - - -We now take ``topk`` as an example to explain how to write the -``FTVMStrategy`` function: - -.. code:: python - - # add to python/tvm/relay/op/strategy/generic.py - @override_native_generic_func("topk_strategy") - def topk_strategy(attrs, inputs, out_type, target): - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_topk(topi.topk), - wrap_topi_schedule(topi.generic.schedule_topk), - name="topk.generic") - return strategy - - # add to each target file in python/tvm/relay/op/strategy, e.g., x86.py, cuda.py, etc. - @topk_strategy.register(["cuda", "gpu"]) - def topk_strategy_cuda(attrs, inputs, out_type, target): - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_my_new_op(topi.cuda.topk), - wrap_topi_schedule(topi.cuda.schedule_topk), - name="topk.cuda") - return strategy - -In this example, we use ``topi.cuda.topk`` and ``topi.cuda.schedule_topk`` -as the compute and schedule function for CUDA or GPU target, while use TOPI -generic compute and schedule for the rest of targets. -Note that we use two wrapper functions that wrap the topi -compute and schedule to conform with the required function signature ( -see ``FTVMCompute`` and ``FTVMSchedule`` in ``include/tvm/relay/op_attr_types.h``). -Usually we need to write a customized compute wrapper function for each operator -to get different fields from op attributes. - -The example above shows a very basic strategy function that only -adds one implementation in the strategy. But for many complicated operators, -we may need to add multiple implementations that use different algorithms. -For example, we can use both direct and winograd algorithm to -compute a conv2d op. In order to achieve this, we can write the strategy function -as follows: - -.. code:: python - - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), - name="conv2d_nchw.cuda", - plevel=10) - - if winograd_condition: - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd), - name="conv2d_nchw_winograd.cuda", - plevel=15) - -In this example, we add two implementations to the conv2d strategy where -winograd algorithm is only added when ``winograd_condition`` is true. -The implementation ``"conv2d_nchw_winograd.cuda"`` will be used to compile -conv2d when ``winograd_condition`` is true as it has higher -priority level (this could be changed if certain implementation is an AutoTVM -template. See `Select Implementation from Op Strategy`_ for more -details). Otherwise, ``"conv2d_nchw.cuda"`` is used. - -We can extend the example above to third party library implementation. For -example, we can add the implementation that invokes kernel in the cblas -library when cblas is included in the target. - -.. code:: python - - if "cblas" in target.libs: - strategy.add_implementation( - wrap_compute_dense(topi.x86.dense_cblas), - wrap_topi_schedule(topi.x86.schedule_dense_cblas), - name="dense_cblas.x86", - plevel=15) - - -Further, we can add implementation specialized for a certain range of shapes. -The code below shows an example of dense strategy that adds an implementation -that is specialized for ``m`` greater than 16. The main difference between -hardcode python condition like examples above and specialized condition is that -it allows TVM to generate multiple kernels when the input tensors have symbolic -shapes. The compile engine will generate a dispatch function that invokes the -specialized kernel when the corresponding condition is met; otherwise, -invoke the kernel that has no associated specialized condition (``dense_common`` -in this example). This part is still work in progress. More details will be -provided after it is done. - -.. code:: python - - def dense_strategy(attrs, inputs, out_type, target): - m = inputs[0].shape[0] - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_dense(dense_compute1), - wrap_topi_schedule(dense_schedule1), - name="dense_common") - - with tvm.te.SpecializedCondition(m > 16): - strategy.add_implementation( - wrap_compute_dense(dense_compute2), - wrap_topi_schedule(dense_schedule2), - name="dense_for_large_m", - plevel=15) - - return strategy - - -Register Strategy Function to An Operator ------------------------------------------ - -After we define the strategy function for an operator, we can now -register the strategy function to this operator with - -.. code:: python - - register_strategy("topk", strategy.topk_strategy) - -However, it takes much effort to write a strategy function for an operator. -Therefore, we provide two other methods for simpler operators. - -First, for operators that have injective, broadcast, or reduction pattern, we -can call ``register_injective_schedule``, ``register_broadcast_schedule``, and -``register_reduce_schedule`` repsectively. The schedule function for these -patterns are already registered by each target and can be applied to these -operators. We assume the compute function should be the same across all targets, -and ``FTVMCompute`` needs to be registered to the op before invoking register -schedule. - -.. code:: python - - register_broadcast_schedule("add") - -Second, for operators that doesn't have these common patterns mentioned before, -but also have the same compute function for all targets, we can use -``register_schedule`` API. It is easier to write ``FTVMSchedule`` function -as we only need to provide which schedule function to use. The following -code snippet shows ``FTVMSchedule`` function for pooling. - -.. code:: python - - # add to python/tvm/relay/op/strategy/generic.py - @generic_func - def schedule_pool(attrs, outs, target): - with target: - return topi.generic.schedule_pool(outs, attrs.layout) - - # add to each target file in python/tvm/relay/op/strategy, e.g., x86.py, cuda.py, etc. - @schedule_pool.register("cpu") - def schedule_pool_cpu(attrs, outs, target): - ... - -After we created the ``FTVMSchedule`` for an operator, we can -register the strategy using ``register_schedule``: - -.. code:: python - - register_schedule("nn.max_pool2d", strategy.schedule_pool) - - -Register Strategies for A New Target ------------------------------------- - -There are two ways to register strategies for a new target. The more -straightforward one is adding a new target file in the directory -``python/tvm/relay/op/strategy``. You only need to customize the strategy for -ops that have been implemented for this new target and reuse the generic -strategies for the rest. - -Alternatively, you can also register the strategy for the new target outside the -TVM python library. The following code snippet shows an example how to do -so. You can find more examples in ``vta/python/vta/top/op.py``. - -.. code:: python - - @relay.op.strategy.conv2d_strategy.register("mytarget") - def conv2d_strategy_mytarget(attrs, inputs, out_type, target): - ... - - -Select Implementation from Op Strategy --------------------------------------- - -During the compilation, Relay compile engine needs to determine which -implementation to use for an operator when there are multiple. The selection -policy works as follows. - -When the input tensors to an operator or a fused op all have constant shapes, -the compile engine first finds the best implementation based on AutoTVM tuning -logs. If there is no implementation that is an AutoTVM template or all AutoTVM -templates have fallback configs, the implementation with highest priority level -will then be chosen. Implementations with same priority level in this case leads -to an undefined behavior, and any of them might be selected. - -The selection policy for ops with symbolic input shapes is still work in -progress. Currently, if any input tensor has a symbolic shape, only the -implementation with highest priority level will be used for this operator. This -will be updated after the implementation finishes. - -For debug purpose, you can add the following lines before you compile the Relay -model to learn which implementation is used for each operator. - -.. code:: python - - logging.getLogger("te_compiler").setLevel(logging.INFO) - logging.getLogger("te_compiler").addHandler(logging.StreamHandler(sys.stdout)) diff --git a/docs/arch/virtual_machine.rst b/docs/arch/virtual_machine.rst deleted file mode 100644 index c532392afeb8..000000000000 --- a/docs/arch/virtual_machine.rst +++ /dev/null @@ -1,410 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -Putting the VM in TVM: The Relay Virtual Machine -================================================ - -Relay, a new program representation, has enabled the representation and optimization of -a great breadth of machine learning programs. -Unfortunately, by supporting a more expressive set of programs, we have -introduced several new execution challenges. - -Relay's interpreter can execute the full language but has notable limitations -that make it unsuited for production deployments. It is structured as an inefficient -interpreter that performs AST traversal to execute the program. This approach is conceptually -simple but inefficient, as the AST traversal heavily relies on indirection. - -There are further challenges in compiling dynamic code, such as dynamic scheduling and allocation, -fully dynamic tensor shapes, and control flow. The interpreter offers simple solutions -for these, but none is sufficiently compelling or optimized. - -The second execution mechanism is the existing graph executor. In order to target Relay -programs to this, we compile a small subset of them to the old graph format and execute -them on the runtime. Graph executor provides a fast execution experience but only for a very limited -subset of Relay programs. - -An alternative but not-standard approach is Relay's ahead-of-time compiler, -which compiles a Relay program into a shared library containing an ahead-of-time -implementation. The ahead-of-time compiler provides compelling performance -but is difficult to extend and instrument, which can only be done by modifying the -code generation and optimization mechanisms. - -The Relay virtual machine is intended to be a framework that balances these competing -approaches, providing a dynamic execution environment which can be extended, instrumented, -and integrated with other approaches like ahead-of-time compilation via a flexible extension -mechanism. - -The virtual machine is designed to strike a balance between performance and flexibility -when deploying and executing Relay programs, without giving up the benefits of TVM. - -Virtual machine (VM) design is a well-studied area in programming languages and systems, -and there have been various virtual machine designs for both full-fledged -and embedded programing languages. -Previous language VM designs have been heavily tailored to the execution profile of traditional programs. -Traditional programs manipulate small scalar values and consist of a large number of low-level instructions. -The sheer quantity of instructions requires instruction execution and dispatch to be extremely efficient. -In the context of machine learning we manipulate primarily tensor values, using a (relatively) -low number of high level instructions. ML programs' cost centers are expensive operator invocations, -such as GEMM or convolution, over a large input. Due to the execution profile exhibited by ML programs, -micro-optimizations present in scalar VMs are dramatically less important. - -TVM has provided strong support for vision models, -but we want to grow to support a wider variety of models. -The graph executor is able to utilize the fully static nature of the input graphs to perform -aggressive optimization such as fully static allocation, and optimal memory reuse. -When we introduce models which make use of control flow, recursion, dynamic shapes, and dynamic -allocation, we must change how execution works. A virtual machine for Relay is a natural choice. - -The rest of this document provides a high-level overview of the Relay -virtual machine design and its instruction set. - -Design ------- - -The VM's design is focused on simplicity without sacrificing performance. -In order to accomplish this we have focused on designing a tensor VM rather than a scalar VM. - -In the tensor VM setting, we optimize for cheap “allocation” of objects (by trying to avoid real allocation), -reuse of static fragments, and the ability to do dynamic shape (i.e jagged tensors). - -Instruction Set -~~~~~~~~~~~~~~~ - -The choices of an instruction set and instruction representation are the most critical design decisions for a VM. -The current representation of the instructions is a tagged union containing the op-code and the data payload. An important design decision is the level of abstraction of the instructions (RISC vs. CISC) and how they take their data (fixed-width instruction encoding vs. variable-length encoding). The current version is closer to CISC, with complex instructions like AllocTensor, and is variable-length due to the inclusion of the shape as part of the instruction. The current instruction set is very high-level and corresponds roughly to high-level operations in Relay. - -Ret -^^^ -**Arguments**: -:: - - RegName dst - RegName result - -Returns the object in register ``result`` to caller's register ``dst``. - -InvokePacked -^^^^^^^^^^^^ -**Arguments**: -:: - - Index packed_index - Index arity - Index output_size - RegName* packed_args - -Invoke the packed function denoted by ``packed_index``. The ``arity`` -and ``output_size`` are used to inform the VM how many inputs and -outputs to expect. ``packed_args`` stores the list of argument registers. Note ``Index`` -is an alias of ``int64_t``, and it will be used in other instructions as well. - -AllocTensor -^^^^^^^^^^^ -**Arguments**: -:: - - RegName dst - RegName storage - uint32_t ndim - int64_t* shape - DLDataType dtype - -Allocate a tensor value of using constant shape (stored in ``shape``) and ``dtype`` -from the given storage block, ``storage``. The result is saved to register ``dst``. - -AllocTensorReg -^^^^^^^^^^^^^^ -**Arguments**: -:: - - RegName dst - RegName storage - RegName shape_register - DLDataType dtype - -Allocate a tensor value of the appropriate shape (stored in ``shape_register``) -and ``dtype`` from the given storage block (stored in ``storage``). The result is saved to register ``dst``. - -AllocStorage -^^^^^^^^^^^^ -**Arguments**: -:: - - RegName dst - RegName size - RegName alignment - DLDataType dtype_hint - -Allocate a storage block with the given ``size``, ``alignment`` and data type, ``dtype_hint``. -The allocated storage block is stored in register ``dst``. - -AllocADT -^^^^^^^^ -**Arguments**: -:: - - RegName dst - Index tag - Index num_fields - RegName* datatype_fields - -Allocate a data type with the tag ``tag`` using the ``num_fields`` entries -from registers ``datatype_fields``. The result is saved to register ``dst``. - -AllocClosure -^^^^^^^^^^^^ -**Arguments**: -:: - - RegName dst - Index clo_index - Index num_freevar - RegName* free_vars; - -Allocate a closure with the VMFunction at ``clo_index`` as -its code, and the ``num_freevar`` entries from registers in -``free_vars``. The result is saved to register ``dst``. - -GetField -^^^^^^^^ -**Arguments**: -:: - - RegName dst - RegName object - Index field_index - -Get the field value with index ``field_index`` from ``object``. And saves the result to register ``dst``. - -If -^^ -**Arguments**: -:: - - RegName test - RegName target - Index true_offset - Index false_offset - -Check if the object at register ``test`` is equal to ``target``. -If equal, relative jump by ``true_offset``, else relative -jump by ``false_offset``. - -GetTag -^^^^^^ -**Arguments**: -:: - - RegName object - RegName dst - -Get the object tag for ADT object in register ``object``. And saves the reult to register ``dst``. - -Fatal -^^^^^ -Fail the virtual machine execution. - -Goto -^^^^ -**Arguments**: -:: - - Index pc_offset - -Relative unconditional jump by ``pc_offset``. - -Invoke -^^^^^^ -**Arguments**: -:: - - Index func_index - -Invoke function at ``func_index``, consumes the number of arguments contained in the VMFunction's -arity field. - -InvokeClosure -^^^^^^^^^^^^^ -**Arguments**: -:: - - RegName closure - Index num_closure_args - RegName* closure_args - -Invokes ``closure``, consuming the number of arguments declared in the closure's VMFunction. - -LoadConst -^^^^^^^^^ -**Arguments**: -:: - - RegName dst - Index const_index - -Load the constant at ``const_index`` from the constant pool. The result is saved to register ``dst``. - -LoadConsti -^^^^^^^^^^ -**Arguments**: -:: - - Index val - RegName dst - -Load the constant integer ``val`` to register ``dst``. The result is a 0-rank tensor. - -Object Representation -~~~~~~~~~~~~~~~~~~~~~ -We leverage the object protocol to represent the objects that are used by the -VM. - -Currently, three types of objects, ``NDArray``, ``ADT``, and ``Closure`` objects, are used -to represent tensor, tuple/list, and closure data, respectively. More details -for each of them can be found at `include/tvm/runtime/ndarray.h`_, -`include/tvm/runtime/vm/vm.h`_, and `include/tvm/runtime/container.h`_, respectively. - -.. _include/tvm/runtime/ndarray.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/ndarray.h - -.. _include/tvm/runtime/vm/vm.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/vm/vm.h - -.. _include/tvm/runtime/container.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/container.h - -Stack and State -~~~~~~~~~~~~~~~ - -The Relay VM maintains a stack frame, which contains information about how to resume the -previous call. Registers are allocated in a continuous space (virtual register file) for each function. - -We keep track of a set of Relay functions we have called, a pointer into its bytecode, an offset into the byte code (known as the program counter). - -.. code-block:: c - - struct VirtualMachine { - ... - std::vector frames; - ... - // Current function. - size_t func_index; - // Pointer into the current function's instructions. - const Instruction* code; - // Current program counter relative to the code pointer. - size_t pc; - ... - }; - - -Dispatch Loop -~~~~~~~~~~~~~ -A critical piece of a VM is the dispatch loop. The dispatch loop usually dominates the execution time of a -virtual machine, but we have experimentally found this not to be the case for Relay. We have just implemented -a simple ``switch``/``goto`` dispatch loop which dispatches based on instruction op code. - -This loop is implemented by ``VirtualMachine::Run()``. - -VM Compiler -~~~~~~~~~~~ - -An important part of this infrastructure is a compiler from Relay's full IR into a sequence of bytecode. -The VM compiler transforms a ``tvm::relay::Module`` into a ``tvm::relay::vm::Executable``. The executable -contains a set of compiled functions, the compiled functions are contained in ``tvm::relay::vm::Function``. -The functions contain metadata about the function as well as its compiled bytecode. The emitted executable -object then can be loaded and run by a ``tvm::relay::vm::VirtualMachine`` object. For full definitions of the -data structures, please see `include/tvm/runtime/vm/executable.h`_ and `include/tvm/runtime/vm/vm.h`_. - -.. _include/tvm/runtime/vm/executable.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/vm/executable.h - -Optimizations -~~~~~~~~~~~~~ - -There are quite a few optimizations required by the VM compiler. Each of them -is implemented as a pass which is managed by the Relay pass manager. - -Optimizations marked with `TODO` are not implemented yet. - -- A-Normal Form -- Lambda Lift (see `src/relay/vm/lambda_lift.cc`_) -- Inline Primitives (see `src/relay/vm/inline_primitives.cc`_) -- Constant Pool Layout (see `src/relay/backend/vm/compiler.cc`_) -- Tail Call Optimization (TODO) -- Liveness Analysis (TODO) - -.. _src/relay/vm/lambda_lift.cc: https://github.com/apache/tvm/blob/main/src/relay/backend/vm/lambda_lift.cc - -.. _src/relay/vm/inline_primitives.cc: https://github.com/apache/tvm/blob/main/src/relay/backend/vm/inline_primitives.cc - -.. _src/relay/backend/vm/compiler.cc: https://github.com/apache/tvm/blob/main/src/relay/backend/vm/compiler.cc - -Serialization -~~~~~~~~~~~~~ - -Serializing and deserializing the executable generated by the Relay VM compiler is a must as -we may want to save the model to the disk and perform inference later. Previously, Relay has produced -a serialized form in a json file for the graph executor. However, the same format is not directly -applicable to the VM as it emits bytecode instead of graph-style programs. -Serialization of an executable essentially needs to handle both model specific -(i.e. weights and kernels) and VM related (i.e. bytecode and global function names) data. - -For kernels, we can conveniently leverage existing TVM infra to save and load -the compiled library module. Here we only focus on serializing other several -components in a binary format that is organized with the following sections in order. - -- Global section. This section contains the globals (function names) used by the virtual machine. - -- Constant section. This section is used to store the constant pool (i.e. weights of the model) - for a virtual machine. - -- Primitive name section. This section is introduced to accommodate the list of primitive - operator names that will be invoked by the virtual machine, i.e. the names - starting with ``fused_``. The primitive names are used as symbols to look up - function pointers in the compiled kernel library. - -- Code section. The VM functions, including bytecode, are sitting in this section. The dispatching - loop iterates through this section to fetch instructions for execution. - -Hence, unlike the graph executor artifact that contains weight (.params), graph json (.json), -and compiled kernel library (.so), the serialized executable artifact is composed of the Relay -object file (.ro) and the compiled kernel library (.so). - -A ``save`` function is implemented to store the executable to the disk and -serialize it into the above format. Meanwhile, a ``load_exec`` function is used to -load the serialized kernel binary and executable related binary code, which will be again used to -instantiate a VM object. Please refer to the `test_vm_serialization.py`_ file for more -examples. - -.. _test_vm_serialization.py: https://github.com/apache/tvm/blob/main/tests/python/relay/test_vm_serialization.py - -Unresolved Questions -~~~~~~~~~~~~~~~~~~~~ - -How do we handle dynamic shapes? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Dynamic shape support is ongoing work in TVM as we upgrade Relay, TVM's compiler. For the most recent updates on -dynamic shape support, we recommend following updates in TVM's Discuss forum (https://discuss.tvm.apache.org/). - -How can we modify the VM to support JIT compilation of certain code paths? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the code generation space there are still many tradeoffs to be analyzed and the VM is designed -to be very flexible so we can modify it for future experiments. - -How do we support heterogenous execution? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Heterogenous execution should work out of the box assuming we have annotated the appropriate device copies. -In order to do this properly we need to run the device annotation and copying passes. diff --git a/docs/deep_dive/relax/index.rst b/docs/deep_dive/relax/index.rst index f891eb2793ec..2b7c4ea599ae 100644 --- a/docs/deep_dive/relax/index.rst +++ b/docs/deep_dive/relax/index.rst @@ -15,7 +15,7 @@ specific language governing permissions and limitations under the License. -.. _relax: +.. _relax-deep-dive: Relax ===== diff --git a/docs/deep_dive/tensor_ir/index.rst b/docs/deep_dive/tensor_ir/index.rst index 46bed7c42319..66e153ec01a5 100644 --- a/docs/deep_dive/tensor_ir/index.rst +++ b/docs/deep_dive/tensor_ir/index.rst @@ -15,7 +15,7 @@ specific language governing permissions and limitations under the License. -.. _tensor-ir: +.. _tensor-ir-deep-dive: TensorIR ======== diff --git a/docs/dev/tutorial/codebase_walkthrough.rst b/docs/dev/tutorial/codebase_walkthrough.rst index 726e253057d0..a349b69f7b58 100644 --- a/docs/dev/tutorial/codebase_walkthrough.rst +++ b/docs/dev/tutorial/codebase_walkthrough.rst @@ -124,7 +124,7 @@ Lowering is done by ``tvm.lower()`` function, defined in ``python/tvm/build_modu stmt = schedule.ScheduleOps(sch, bounds) ... -Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/te/schedule/bound.cc``, ``src/te/schedule/graph.cc`` and ``src/te/schedule/message_passing.cc``. For more information on how bound inference works, see :ref:`dev-InferBound-Pass`. +Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/te/schedule/bound.cc``, ``src/te/schedule/graph.cc`` and ``src/te/schedule/message_passing.cc``. ``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/te/schedule/schedule_ops.cc``. diff --git a/docs/index.rst b/docs/index.rst index 2102bdd33a00..3abc39e82fd1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -54,6 +54,7 @@ driving its costs down. :maxdepth: 2 :caption: Deep Dive + arch/index deep_dive/tensor_ir/index deep_dive/relax/index @@ -73,7 +74,6 @@ driving its costs down. dev/tutorial/index dev/how_to/how_to.rst reference/langref/index - arch/index topic/microtvm/index topic/vta/index diff --git a/docs/reference/langref/relay_expr.rst b/docs/reference/langref/relay_expr.rst index c50acc2949dd..c789331efe63 100644 --- a/docs/reference/langref/relay_expr.rst +++ b/docs/reference/langref/relay_expr.rst @@ -540,9 +540,7 @@ the graph node will only be evaluated once by the compiled program. These bindings allow for a style of programming that corresponds to that already employed by NNVM and other dataflow graph-based input formats. The fact that the variables are not scoped offers some flexibility in evaluation order compared to :code:`let` -bindings, though this can also introduce some ambiguity in programs (the -:ref:`developer introduction to the Relay IR` includes more detailed discussion -of this nuance). +bindings, though this can also introduce some ambiguity in programs. *Note: Graph bindings are not currently parsed by the text format.* diff --git a/docs/topic/microtvm/index.rst b/docs/topic/microtvm/index.rst index 4dd4ab5d511d..2bac70241d3b 100644 --- a/docs/topic/microtvm/index.rst +++ b/docs/topic/microtvm/index.rst @@ -58,13 +58,6 @@ more as they follow through them. Here is a list of tutorials that you can start 3. Try running a more complex tutorial: :ref:`Creating Your MLPerfTiny Submission with microTVM `. -How microTVM Works -~~~~~~~~~~~~~~~~~~ - - -You can read more about the design of these pieces at the :ref:`microTVM Design Document `. - - Help and Discussion ~~~~~~~~~~~~~~~~~~~ diff --git a/gallery/how_to/tune_with_autoscheduler/tune_network_arm.py b/gallery/how_to/tune_with_autoscheduler/tune_network_arm.py index d795c3aba245..e4edf0333508 100644 --- a/gallery/how_to/tune_with_autoscheduler/tune_network_arm.py +++ b/gallery/how_to/tune_with_autoscheduler/tune_network_arm.py @@ -70,7 +70,6 @@ # with any layout, we found the best performance is typically achieved with NHWC layout. # We also implemented more optimizations for NHWC layout with the auto-scheduler. # So it is recommended to convert your models to NHWC layout to use the auto-scheduler. -# You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=False): diff --git a/gallery/how_to/tune_with_autoscheduler/tune_network_cuda.py b/gallery/how_to/tune_with_autoscheduler/tune_network_cuda.py index 1f8c0cc13a35..f11aef253f81 100644 --- a/gallery/how_to/tune_with_autoscheduler/tune_network_cuda.py +++ b/gallery/how_to/tune_with_autoscheduler/tune_network_cuda.py @@ -64,7 +64,6 @@ # with any layout, we found the best performance is typically achieved with NHWC layout. # We also implemented more optimizations for NHWC layout with the auto-scheduler. # So it is recommended to convert your models to NHWC layout to use the auto-scheduler. -# You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. def get_network(name, batch_size, layout="NHWC", dtype="float32"): diff --git a/gallery/how_to/tune_with_autoscheduler/tune_network_mali.py b/gallery/how_to/tune_with_autoscheduler/tune_network_mali.py index 15f337901360..3120c30cef1a 100644 --- a/gallery/how_to/tune_with_autoscheduler/tune_network_mali.py +++ b/gallery/how_to/tune_with_autoscheduler/tune_network_mali.py @@ -67,7 +67,6 @@ # with any layout, we found the best performance is typically achieved with NHWC layout. # We also implemented more optimizations for NHWC layout with the auto-scheduler. # So it is recommended to convert your models to NHWC layout to use the auto-scheduler. -# You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. def get_network(name, batch_size, layout="NHWC", dtype="float32"): diff --git a/gallery/how_to/tune_with_autoscheduler/tune_network_x86.py b/gallery/how_to/tune_with_autoscheduler/tune_network_x86.py index 169567122f79..43314a4b0a2f 100644 --- a/gallery/how_to/tune_with_autoscheduler/tune_network_x86.py +++ b/gallery/how_to/tune_with_autoscheduler/tune_network_x86.py @@ -67,7 +67,6 @@ # with any layout, we found the best performance is typically achieved with NHWC layout. # We also implemented more optimizations for NHWC layout with the auto-scheduler. # So it is recommended to convert your models to NHWC layout to use the auto-scheduler. -# You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=False): diff --git a/gallery/how_to/work_with_microtvm/micro_tvmc.sh b/gallery/how_to/work_with_microtvm/micro_tvmc.sh index dded94e55603..bf9338cf5f7f 100755 --- a/gallery/how_to/work_with_microtvm/micro_tvmc.sh +++ b/gallery/how_to/work_with_microtvm/micro_tvmc.sh @@ -96,7 +96,7 @@ wget https://github.com/tensorflow/tflite-micro/raw/a56087ffa2703b4d5632f024a8a4 # # Model Library Format (MLF) is an output format that TVM provides for micro targets. MLF is a tarball # containing a file for each piece of the TVM compiler output which can be used on micro targets outside -# TVM environment. Read more about :ref:`Model Library Format `. +# TVM environment. # # Here, we generate a MLF file for ``qemu_x86`` Zephyr board. You can chooses `aot` or `graph` executor type # to run this tutorial, however, we recommend to use `aot` for microTVM targets since `aot` uses ahead of time From 2a87c4cfc075b2cce18738cc270a2229cfb50de7 Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Mon, 23 Sep 2024 21:42:37 -0400 Subject: [PATCH 168/202] [BYOC][NNAPI] Add NNAPI backend for BYOC (#17385) * [BYOC][NNAPI] This PR intorduce NNAPI to TVM This PR introduces a new BYOC backend for Android Neural Networks API (NNAPI), enabling execution of neural networks on custom accelerators. This feature adds a new codegen and runtime for NNAPI, supporting operations such as element-wise ops, nn.dense, and nn.conv2d for CNN model with static shape. Co-authored-by: Ming-Long Huang Co-authored-by: HMZ --- CMakeLists.txt | 3 + cmake/modules/LibInfo.cmake | 2 + cmake/modules/contrib/NNAPI.cmake | 39 ++ python/tvm/relax/backend/contrib/nnapi.py | 324 ++++++++++ python/tvm/testing/utils.py | 6 + src/relax/backend/contrib/nnapi/codegen.cc | 272 ++++++++ src/runtime/contrib/nnapi/nnapi_builder.cc | 264 ++++++++ src/runtime/contrib/nnapi/nnapi_builder.h | 133 ++++ src/runtime/contrib/nnapi/nnapi_ops.cc | 601 ++++++++++++++++++ src/runtime/contrib/nnapi/nnapi_ops.h | 165 +++++ src/runtime/contrib/nnapi/nnapi_runtime.cc | 250 ++++++++ src/support/libinfo.cc | 10 + tests/python/nightly/test_nnapi/__init__.py | 17 + tests/python/nightly/test_nnapi/conftest.py | 39 ++ .../nightly/test_nnapi/infrastructure.py | 143 +++++ .../python/nightly/test_nnapi/test_network.py | 136 ++++ tests/python/nightly/test_nnapi/test_ops.py | 362 +++++++++++ 17 files changed, 2766 insertions(+) create mode 100644 cmake/modules/contrib/NNAPI.cmake create mode 100644 python/tvm/relax/backend/contrib/nnapi.py create mode 100644 src/relax/backend/contrib/nnapi/codegen.cc create mode 100644 src/runtime/contrib/nnapi/nnapi_builder.cc create mode 100644 src/runtime/contrib/nnapi/nnapi_builder.h create mode 100644 src/runtime/contrib/nnapi/nnapi_ops.cc create mode 100644 src/runtime/contrib/nnapi/nnapi_ops.h create mode 100644 src/runtime/contrib/nnapi/nnapi_runtime.cc create mode 100644 tests/python/nightly/test_nnapi/__init__.py create mode 100644 tests/python/nightly/test_nnapi/conftest.py create mode 100644 tests/python/nightly/test_nnapi/infrastructure.py create mode 100644 tests/python/nightly/test_nnapi/test_network.py create mode 100644 tests/python/nightly/test_nnapi/test_ops.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 38dd59b9c906..66ea6a07da85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -125,6 +125,8 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF) tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "Build with Arm Compute Library graph executor" OFF) tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF) tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF) +tvm_option(USE_NNAPI_CODEGEN "Build with NNAPI Codegen support" OFF) +tvm_option(USE_NNAPI_RUNTIME "Build with NNAPI runtime" OFF) tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions, STATIC, DYNAMIC, or OFF" OFF) tvm_option(USE_VITIS_AI "Build with VITIS-AI Codegen support" OFF) tvm_option(SUMMARIZE "Print CMake option summary after configuring" OFF) @@ -602,6 +604,7 @@ include(cmake/modules/contrib/BNNS.cmake) include(cmake/modules/contrib/ONNX.cmake) include(cmake/modules/contrib/ArmComputeLib.cmake) include(cmake/modules/contrib/TensorRT.cmake) +include(cmake/modules/contrib/NNAPI.cmake) include(cmake/modules/contrib/VitisAI.cmake) include(cmake/modules/contrib/Verilator.cmake) include(cmake/modules/contrib/UMA.cmake) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index a2b51bb33195..ee6561dffce8 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -144,6 +144,8 @@ function(add_lib_info src_file) TVM_INFO_USE_MSC="${USE_MSC}" TVM_INFO_USE_CCACHE="${USE_CCACHE}" TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}" + TVM_INFO_USE_NNAPI_CODEGEN="${USE_NNAPI_CODEGEN}" + TVM_INFO_USE_NNAPI_RUNTIME="${USE_NNAPI_RUNTIME}" TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}" ) diff --git a/cmake/modules/contrib/NNAPI.cmake b/cmake/modules/contrib/NNAPI.cmake new file mode 100644 index 000000000000..23eb6dd11eda --- /dev/null +++ b/cmake/modules/contrib/NNAPI.cmake @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NNAPI Codegen +if(USE_NNAPI_CODEGEN) + message(STATUS "Build with NNAPI codegen") + + tvm_file_glob(GLOB COMPILER_NNAPI_SRCS src/relax/backend/contrib/nnapi/*.cc) + tvm_file_glob(GLOB RUNTIME_NNAPI_SRCS src/runtime/contrib/nnapi/*.cc) + list(APPEND COMPILER_SRCS ${COMPILER_NNAPI_SRCS}) + if(NOT USE_NNAPI_RUNTIME) + list(APPEND COMPILER_SRCS ${RUNTIME_NNAPI_SRCS}) + endif() +endif() + +# NNAPI Runtime +if(USE_NNAPI_RUNTIME) + message(STATUS "Build with NNAPI runtime") + + tvm_file_glob(GLOB RUNTIME_NNAPI_SRCS src/runtime/contrib/nnapi/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_NNAPI_SRCS}) + list(APPEND TVM_RUNTIME_LINKER_LIBS neuralnetworks log) + + add_definitions(-DTVM_GRAPH_EXECUTOR_NNAPI) +endif() diff --git a/python/tvm/relax/backend/contrib/nnapi.py b/python/tvm/relax/backend/contrib/nnapi.py new file mode 100644 index 000000000000..6e428b60d584 --- /dev/null +++ b/python/tvm/relax/backend/contrib/nnapi.py @@ -0,0 +1,324 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern table for NNAPI backend""" +from typing import ( + Mapping, + Optional, + Tuple, + List, +) +from tvm.ir import IRModule +from tvm.relax.transform import FuseOpsByPattern, MergeCompositeFunctions +from tvm.relax.dpl.pattern import ( + DFPattern, + wildcard, + is_op, +) + +from ..pattern_registry import get_patterns_with_prefix, register_patterns + + +def elementwise_binary_patterns() -> List[Tuple[str, DFPattern, Mapping[str, DFPattern]]]: + """ + Returns a list of tuples representing elementwise binary operation patterns mapped + between NNAPI and Relax frameworks. + """ + + def _elementwise_binary_pattern( + pattern_name: str, + op_name: str, + ) -> Tuple[str, DFPattern, Mapping[str, DFPattern]]: + input0 = wildcard() + input1 = wildcard() + + pattern = is_op(op_name)(input0, input1) + + return (pattern_name, pattern, {}) + + return [ + _elementwise_binary_pattern("nnapi.add", "relax.add"), + _elementwise_binary_pattern("nnapi.mul", "relax.multiply"), + _elementwise_binary_pattern("nnapi.div", "relax.divide"), + _elementwise_binary_pattern("nnapi.sub", "relax.subtract"), + _elementwise_binary_pattern("nnapi.pow", "relax.power"), + _elementwise_binary_pattern("nnapi.equal", "relax.equal"), + _elementwise_binary_pattern("nnapi.greater", "relax.greater"), + _elementwise_binary_pattern("nnapi.greater_equal", "relax.greater_equal"), + _elementwise_binary_pattern("nnapi.less", "relax.less"), + _elementwise_binary_pattern("nnapi.less_equal", "relax.less_equal"), + _elementwise_binary_pattern("nnapi.not_equal", "relax.not_equal"), + _elementwise_binary_pattern("nnapi.maximum", "relax.maximum"), + _elementwise_binary_pattern("nnapi.minimum", "relax.minimum"), + ] + + +def unary_patterns() -> List[Tuple[str, DFPattern, Mapping[str, DFPattern]]]: + """ + Returns a list of tuples representing unary operation patterns mapped + between NNAPI and Relax frameworks. + """ + + def _unary_pattern( + pattern_name: str, op_name: str + ) -> Tuple[str, DFPattern, Mapping[str, DFPattern]]: + input0 = wildcard() + pattern = is_op(op_name)(input0) + return (pattern_name, pattern, {}) + + return [ + _unary_pattern("nnapi.floor", "relax.floor"), + _unary_pattern("nnapi.relu", "relax.nn.relu"), + _unary_pattern("nnapi.logistic", "relax.sigmoid"), + _unary_pattern("nnapi.softmax", "relax.nn.softmax"), + _unary_pattern("nnapi.tanh", "relax.tanh"), + _unary_pattern("nnapi.abs", "relax.abs"), + _unary_pattern("nnapi.exp", "relax.exp"), + _unary_pattern("nnapi.log", "relax.log"), + _unary_pattern("nnapi.neg", "relax.negative"), + _unary_pattern("nnapi.cast", "relax.astype"), + _unary_pattern("nnapi.sqrt", "relax.sqrt"), + _unary_pattern("nnapi.rsqrt", "relax.rsqrt"), + ] + + +def matmul_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]: + """ + Returns a tuple representing matmul operation patterns mapped + between NNAPI and Relax frameworks. + """ + input0 = wildcard() + input1 = wildcard() + pattern = is_op("relax.matmul")(input0, input1) + return ("nnapi.batch_matmul", pattern, {}) + + +def permute_dims_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]: + """ + Returns a tuple representing permute operation patterns mapped + between NNAPI and Relax frameworks. + """ + input0 = wildcard() + pattern = is_op("relax.permute_dims")(input0) + return ("nnapi.transpose", pattern, {}) + + +def astype_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]: + """ + Returns a tuple representing astype operation patterns mapped + between NNAPI and Relax frameworks. + """ + input0 = wildcard().has_dtype("float16") | wildcard().has_dtype("float32") + pattern = is_op("relax.astype")(input0).has_dtype("float16") | is_op("relax.astype")( + input0 + ).has_dtype("float32") + + return ("nnapi.cast", pattern, {}) + + +def mean_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]: + """ + Returns a tuple representing mean operation patterns mapped + between NNAPI and Relax frameworks. + """ + input0 = wildcard() + pattern = is_op("relax.mean")(input0) + + return ("nnapi.mean", pattern, {}) + + +def conv2d_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]: + """ + Returns a tuple representing conv2d operation patterns mapped + between NNAPI and Relax frameworks. + """ + input0 = wildcard() + input1 = wildcard() + input2 = wildcard() + conv = is_op("relax.nn.conv2d")(input0, input1) + pattern = is_op("relax.add")(conv, input2) + return ("nnapi.conv2d", pattern, {}) + + +def max_pool2d_pattern() -> Tuple[str, DFPattern, Mapping[str, DFPattern]]: + """ + Returns a tuple representing max_pool2d operation patterns mapped + between NNAPI and Relax frameworks. + """ + input0 = wildcard() + pattern = is_op("relax.nn.max_pool2d")(input0) + return ("nnapi.max_pool_2d", pattern, {}) + + +register_patterns( + [ + *elementwise_binary_patterns(), + *unary_patterns(), + matmul_pattern(), + permute_dims_pattern(), + astype_pattern(), + mean_pattern(), + conv2d_pattern(), + max_pool2d_pattern(), + ] +) + + +def min_feature_level(pattern_name: str) -> int: + """ + Returns the minimum feature level required to support a given NNAPI operation pattern. + + Args: + pattern_name (str): The name of the NNAPI operation pattern + (e.g., "nnapi.add", "nnapi.conv2d"). + + Returns: + int: The minimum feature level for the specified pattern, or 1 if the pattern is not found. + """ + + levels = { + "nnapi.add": 1, + "nnapi.average_pool_2d": 1, + "nnapi.concatenation": 1, + "nnapi.conv2d": 1, + "nnapi.depthwise_conv_2d": 1, + "nnapi.depth_to_space": 1, + "nnapi.dequantize": 1, + "nnapi.embedding_lookup": 1, + "nnapi.floor": 1, + "nnapi.fully_connected": 1, + "nnapi.hashtable_lookup": 1, + "nnapi.l2_normalization": 1, + "nnapi.l2_pool_2d": 1, + "nnapi.local_response_normalization": 1, + "nnapi.logistic": 1, + "nnapi.lsh_projection": 1, + "nnapi.lstm": 1, + "nnapi.max_pool_2d": 1, + "nnapi.mul": 1, + "nnapi.relu": 1, + "nnapi.relu1": 1, + "nnapi.relu6": 1, + "nnapi.reshape": 1, + "nnapi.resize_bilinear": 1, + "nnapi.rnn": 1, + "nnapi.softmax": 1, + "nnapi.space_to_depth": 1, + "nnapi.svdf": 1, + "nnapi.tanh": 1, + "nnapi.batch_to_space_nd": 2, + "nnapi.div": 2, + "nnapi.mean": 2, + "nnapi.pad": 2, + "nnapi.space_to_batch_nd": 2, + "nnapi.squeeze": 2, + "nnapi.strided_slice": 2, + "nnapi.sub": 2, + "nnapi.transpose": 2, + "nnapi.abs": 3, + "nnapi.argmax": 3, + "nnapi.argmin": 3, + "nnapi.axis_aligned_bbox_transform": 3, + "nnapi.bidirectional_sequence_lstm": 3, + "nnapi.bidirectional_sequence_rnn": 3, + "nnapi.box_with_nms_limit": 3, + "nnapi.cast": 3, + "nnapi.channel_shuffle": 3, + "nnapi.detection_postprocessing": 3, + "nnapi.equal": 3, + "nnapi.exp": 3, + "nnapi.expand_dims": 3, + "nnapi.gather": 3, + "nnapi.generate_proposals": 3, + "nnapi.greater": 3, + "nnapi.greater_equal": 3, + "nnapi.grouped_conv_2d": 3, + "nnapi.heatmap_max_keypoint": 3, + "nnapi.instance_normalization": 3, + "nnapi.less": 3, + "nnapi.less_equal": 3, + "nnapi.log": 3, + "nnapi.logical_and": 3, + "nnapi.logical_not": 3, + "nnapi.logical_or": 3, + "nnapi.log_softmax": 3, + "nnapi.maximum": 3, + "nnapi.minimum": 3, + "nnapi.neg": 3, + "nnapi.not_equal": 3, + "nnapi.pad_v2": 3, + "nnapi.pow": 3, + "nnapi.prelu": 3, + "nnapi.quantize": 3, + "nnapi.quantized_16bit_lstm": 3, + "nnapi.random_multinomial": 3, + "nnapi.reduce_all": 3, + "nnapi.reduce_any": 3, + "nnapi.reduce_max": 3, + "nnapi.reduce_min": 3, + "nnapi.reduce_prod": 3, + "nnapi.reduce_sum": 3, + "nnapi.roi_align": 3, + "nnapi.roi_pooling": 3, + "nnapi.rsqrt": 3, + "nnapi.select": 3, + "nnapi.sin": 3, + "nnapi.slice": 3, + "nnapi.split": 3, + "nnapi.sqrt": 3, + "nnapi.tile": 3, + "nnapi.topk_v2": 3, + "nnapi.transpose_conv_2d": 3, + "nnapi.unidirectional_sequence_lstm": 3, + "nnapi.unidirectional_sequence_rnn": 3, + "nnapi.resize_nearest_neighbor": 3, + "nnapi.quantized_lstm": 4, + "nnapi.if": 4, + "nnapi.while": 4, + "nnapi.elu": 4, + "nnapi.hard_swish": 4, + "nnapi.fill": 4, + "nnapi.rank": 4, + "nnapi.batch_matmul": 6, + "nnapi.pack": 6, + "nnapi.mirror_pad": 7, + "nnapi.reverse": 7, + } + return levels[pattern_name] + + +def partition_for_nnapi(mod: IRModule, feature_level: Optional[int] = None) -> IRModule: + """Partition the graph greedily offloading supported operators to NNAPI. + + Parameters + ---------- + mod : tvm.ir.IRModule + The module to run passes on. + feature_level : Optional[int] + The maximum NNAPI feature level. + + Returns + ------- + mod : tvm.ir.IRModule + Annotated and partitioned module. + """ + patterns = get_patterns_with_prefix("nnapi") + if feature_level is not None: + patterns = [pat for pat in patterns if feature_level >= min_feature_level(pat.name)] + mod = FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=False)(mod) + mod = MergeCompositeFunctions()(mod) + return mod diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 8227530f7ab7..8b919d2c9dca 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -980,6 +980,12 @@ def _multi_gpu_exists(): target_kind_enabled="opencl", ) +# Mark a test as requiring NNAPI support in build. +requires_nnapi = Feature( + "NNAPI", + "NNAPI", + cmake_flag="USE_NNAPI_CODEGEN", +) # Mark a test as requiring microTVM to run requires_micro = Feature("micro", "MicroTVM", cmake_flag="USE_MICRO") diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc new file mode 100644 index 000000000000..ef74cca70ee8 --- /dev/null +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../../transform/utils.h" +#include "../codegen_json/codegen_json.h" +#include "tvm/relax/attrs/manipulate.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONSerializer = backend::contrib::JSONSerializer; +using JSONGraphNode = backend::contrib::JSONGraphNode; +using JSONGraphNodeEntry = backend::contrib::JSONGraphNodeEntry; +using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; +using NodeEntries = backend::contrib::NodeEntries; + +class NNAPIJSONSerializer; + +class CollectFromCompositeFunctionBody : public ExprVisitor { + public: + explicit CollectFromCompositeFunctionBody(NNAPIJSONSerializer* serializer) + : serializer_(serializer), node_(std::make_shared()) {} + + void VisitExpr_(const CallNode* call_node) override; + + void SetPermuteDimsAttribute(const CallNode* call_node) { + const auto* permute_dims_attr = call_node->attrs.as(); + ICHECK(permute_dims_attr); + if (permute_dims_attr->axes) { + std::vector axes; + for (auto axis : permute_dims_attr->axes.value()) { + axes.push_back(std::to_string(axis.IntValue())); + } + + std::vector axes_attr; + axes_attr.emplace_back(axes); + node_->SetAttr("axes", axes_attr); + } + } + + void SetAstypeAttribute(const CallNode* call_node) { + const auto* astype_attrs = call_node->attrs.as(); + ICHECK(astype_attrs); + + std::vector dtype_attr; + auto dtype_str = runtime::DLDataType2String(astype_attrs->dtype); + dtype_attr.emplace_back(std::vector{dtype_str}); + node_->SetAttr("astype_dtype", dtype_attr); + } + + void SetMeanAttribute(const CallNode* call_node) { + const auto* mean_attrs = call_node->attrs.as(); + ICHECK(mean_attrs); + ICHECK(mean_attrs->axis.defined()); + + { + std::vector axis; + for (auto dim : mean_attrs->axis.value()) { + axis.push_back(std::to_string(dim->value)); + } + + std::vector axis_attr; + axis_attr.emplace_back(axis); + node_->SetAttr("axis", axis_attr); + } + + { + const std::vector keepdims{mean_attrs->keepdims ? "1" : "0"}; + std::vector keepdims_attr; + keepdims_attr.emplace_back(keepdims); + node_->SetAttr("keepdims", keepdims_attr); + } + } + + void SetConv2dAttribute(const CallNode* call_node) { + const auto* conv2d_attr = call_node->attrs.as(); + ICHECK(conv2d_attr) << "didn't catch attributes"; + + std::vector strides; + if (!conv2d_attr->strides.empty()) { + for (auto stride : conv2d_attr->strides) { + const auto* stride_val = stride.as(); + ICHECK(stride_val) << "convertion failed"; + + strides.push_back(std::to_string(stride_val->value)); + } + } else { + strides = {"1", "1"}; + } + + std::vector padding; + for (auto pad : conv2d_attr->padding) { + const auto* padding_val = pad.as(); + + padding.push_back(std::to_string(padding_val->value)); + } + + std::vector groups; + const int group_val = conv2d_attr->groups; + groups.push_back(std::to_string(group_val)); + + std::vector strides_attr; + strides_attr.emplace_back(strides); + node_->SetAttr("strides", strides_attr); + + std::vector padding_attr; + padding_attr.emplace_back(padding); + node_->SetAttr("padding", padding_attr); + + std::vector group_attr; + group_attr.emplace_back(groups); + node_->SetAttr("group", group_attr); + } + + void SetMaxPool2dAttribute(const CallNode* call_node) { + const auto* max_pool_2d_attr = call_node->attrs.as(); + ICHECK(max_pool_2d_attr) << "didn't catch attributes"; + + std::vector strides; + if (!max_pool_2d_attr->strides.empty()) { + for (auto stride : max_pool_2d_attr->strides) { + const auto* stride_val = stride.as(); + ICHECK(stride_val) << "convertion failed"; + + strides.push_back(std::to_string(stride_val->value)); + } + } else { + strides.push_back("1"); + strides.push_back("1"); + } + + std::vector padding; + for (auto pad : max_pool_2d_attr->padding) { + const auto* padding_val = pad.as(); + + padding.push_back(std::to_string(padding_val->value)); + } + + std::vector pool_size; + for (auto size : max_pool_2d_attr->pool_size) { + const auto* pooling_val = size.as(); + + pool_size.push_back(std::to_string(pooling_val->value)); + } + + std::vector strides_attr; + strides_attr.emplace_back(strides); + node_->SetAttr("strides", strides_attr); + + std::vector padding_attr; + padding_attr.emplace_back(padding); + node_->SetAttr("padding", padding_attr); + + std::vector pooling_attr; + pooling_attr.emplace_back(pool_size); + node_->SetAttr("pool_size", pooling_attr); + } + + NNAPIJSONSerializer* serializer_; + JSONGraphObjectPtr node_; +}; + +class NNAPIJSONSerializer : public JSONSerializer { + public: + explicit NNAPIJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + using JSONSerializer::VisitExpr_; + + std::vector VisitExpr_(const CallNode* call_node) final { + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(fn.defined()) << "Expects the callee to be a function."; + + auto composite_opt = fn->GetAttr(attr::kComposite); + ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + + std::string composite_name = composite_opt.value(); + + CollectFromCompositeFunctionBody collector(this); + collector.VisitExpr(fn->body); + + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + + auto node = std::make_shared(composite_name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + node->CaptureAttrs(*collector.node_); + + VLOG(1) << "Adding node " << composite_name << " with " << node->GetInputs().size() + << " inputs"; + return AddNode(node, GetRef(call_node)); + } + + private: + Map bindings_; +}; + +void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { + const auto* op_node = call_node->op.as(); + ICHECK(op_node != nullptr); + std::string name = op_node->name; + if (name == "relax.permute_dims") { + SetPermuteDimsAttribute(call_node); + } else if (name == "relax.astype") { + SetAstypeAttribute(call_node); + } else if (name == "relax.mean") { + SetMeanAttribute(call_node); + } else if (name == "relax.nn.conv2d") { + SetConv2dAttribute(call_node); + } else if (name == "relax.nn.max_pool2d") { + SetMaxPool2dAttribute(call_node); + } else { + } + ExprVisitor::VisitExpr_(call_node); +} + +Array NNAPICompiler(Array functions, Map /*unused*/, + Map constant_names) { + VLOG(1) << "NNAPI Compiler"; + + Array compiled_functions; + for (const auto& func : functions) { + NNAPIJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + auto graph_json = serializer.GetJSON(); + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.nnapi_runtime_create"); + ICHECK(pf != nullptr) << "Cannot find NNAPI runtime module create function."; + auto func_name = GetExtSymbol(func); + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.nnapi").set_body_typed(NNAPICompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/runtime/contrib/nnapi/nnapi_builder.cc b/src/runtime/contrib/nnapi/nnapi_builder.cc new file mode 100644 index 000000000000..d43f00661de9 --- /dev/null +++ b/src/runtime/contrib/nnapi/nnapi_builder.cc @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef TVM_GRAPH_EXECUTOR_NNAPI + +#include "nnapi_builder.h" + +#include +#include + +#include +#include +#include + +#include "../json/json_runtime.h" +#include "nnapi_ops.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +WrappedANeuralNetworksOperandType::WrappedANeuralNetworksOperandType( + int32_t tensor_type, std::vector dimensions, float scale, int32_t zero_point) + : dimensions_(dimensions) { + ty_.type = tensor_type; + if (dimensions_.empty()) { + ty_.dimensions = nullptr; + } else { + ty_.dimensions = dimensions_.data(); + } + ty_.dimensionCount = dimensions_.size(); + ty_.scale = scale; + ty_.zeroPoint = zero_point; +} + +WrappedANeuralNetworksOperandType::WrappedANeuralNetworksOperandType( + const WrappedANeuralNetworksOperandType& other) + : dimensions_(other.dimensions_), ty_(other.ty_) { + if (dimensions_.empty()) { + ty_.dimensions = nullptr; + } else { + ty_.dimensions = dimensions_.data(); + } +} + +WrappedANeuralNetworksOperandType& WrappedANeuralNetworksOperandType::operator=( + const WrappedANeuralNetworksOperandType& other) { + WrappedANeuralNetworksOperandType temp(other); + std::swap(*this, temp); + return *this; +} + +const ANeuralNetworksOperandType* WrappedANeuralNetworksOperandType::Get() const { return &ty_; } + +NNAPIOperand::NNAPIOperand(uint32_t index, const DLTensor* tensor) + : index_(index), scalar_(false), dimensions_(tensor->shape, tensor->shape + tensor->ndim) { + if (dimensions_.size() == 0) { + dimensions_.push_back(1); + } + + tensor_type_ = TensorTypeFromDLDataType(tensor->dtype); + scale_ = 0.0; + zero_point_ = 0; +} + +NNAPIOperand::NNAPIOperand(uint32_t index, const int64_t* shape, int ndim, DLDataType dtype) + : index_(index), scalar_(false), dimensions_(shape, shape + ndim) { + if (dimensions_.size() == 0) { + dimensions_.push_back(1); + } + + tensor_type_ = TensorTypeFromDLDataType(dtype); + scale_ = 0.0; + zero_point_ = 0; +} + +NNAPIOperand::NNAPIOperand(uint32_t index, int32_t tensor_type, std::vector dimensions, + float scale, int32_t zero_point) + : index_(index), + scalar_(false), + tensor_type_(tensor_type), + dimensions_(dimensions), + scale_(scale), + zero_point_(zero_point) { + if (dimensions_.size() == 0) { + dimensions_.push_back(1); + } +} + +NNAPIOperand NNAPIOperand::Scalar(uint32_t index, int32_t tensor_type, + std::vector dimensions, float scale, + int32_t zero_point) { + NNAPIOperand operand(index, tensor_type, dimensions, scale, zero_point); + operand.dimensions_.clear(); + operand.scalar_ = true; + return operand; +} + +void NNAPIOperand::SetDimensions(std::vector dimensions) { dimensions_ = dimensions; } + +WrappedANeuralNetworksOperandType NNAPIOperand::GetOperandType() const { + std::vector dimensions(dimensions_.begin(), dimensions_.end()); + return WrappedANeuralNetworksOperandType(tensor_type_, dimensions, scale_, zero_point_); +} + +uint32_t NNAPIOperand::GetOperandIndex() const { return index_; } + +const std::vector& NNAPIOperand::GetDimensions() const { return dimensions_; } +const float NNAPIOperand::GetScale() const { return scale_; } +const int32_t NNAPIOperand::GetZeroPoint() const { return zero_point_; } + +int32_t NNAPIOperand::GetTensorType() const { return tensor_type_; } +bool NNAPIOperand::IsDynamicShape() const { + return std::any_of(dimensions_.begin(), dimensions_.end(), [](int64_t dim) { return dim == -1; }); +} + +NNAPIModelBuilder::NNAPIModelBuilder() { + ICHECK_EQ(ANeuralNetworksModel_create(&model_), ANEURALNETWORKS_NO_ERROR); +} + +NNAPIModelBuilder::~NNAPIModelBuilder() { ANeuralNetworksModel_free(model_); } + +NNAPIOperand NNAPIModelBuilder::CreateOperandWithValue(const DLTensor& tensor) { + NNAPIOperand operand(next_operand_index_++, &tensor); + const size_t operand_data_size = GetDataSize(tensor); + + ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), tensor.data, + operand_data_size), + ANEURALNETWORKS_NO_ERROR); + + return operand; +} + +NNAPIOperand NNAPIModelBuilder::CreateOperandWithValue(int32_t tensor_type, + std::vector dimensions, float scale, + int32_t zero_point, const void* buffer, + size_t size) { + NNAPIOperand operand(next_operand_index_++, tensor_type, dimensions, scale, zero_point); + ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), buffer, size), + ANEURALNETWORKS_NO_ERROR); + return operand; +} + +NNAPIOperand NNAPIModelBuilder::CreateScalarOperandWithValue(OperandCode operand_code, + const void* buffer, size_t size) { + NNAPIOperand operand = NNAPIOperand::Scalar(next_operand_index_++, operand_code, {}, 0.0f, 0); + + ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), buffer, size), + ANEURALNETWORKS_NO_ERROR); + return operand; +} + +NNAPIOperand NNAPIModelBuilder::CreateOperand(const DLTensor& tensor) { + NNAPIOperand operand(next_operand_index_++, tensor.shape, tensor.ndim, tensor.dtype); + ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + return operand; +} + +NNAPIOperand NNAPIModelBuilder::CreateOperand(const int64_t* shape, int ndim, DLDataType dtype) { + NNAPIOperand operand(next_operand_index_++, shape, ndim, dtype); + ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + return operand; +} + +NNAPIOperand NNAPIModelBuilder::CreateOperand(int32_t tensor_type, std::vector dimensions, + float scale, int32_t zero_point) { + NNAPIOperand operand(next_operand_index_++, tensor_type, dimensions, scale, zero_point); + ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + return operand; +} + +void NNAPIModelBuilder::AddOperation(ANeuralNetworksOperationType operation, + const std::vector input_indicies, + const std::vector output_indicies) { + ICHECK_EQ(ANeuralNetworksModel_addOperation(model_, operation, input_indicies.size(), + input_indicies.data(), output_indicies.size(), + output_indicies.data()), + ANEURALNETWORKS_NO_ERROR); +} + +void NNAPIModelBuilder::Finish(const std::vector& model_input_operands, + const std::vector& model_output_operands) { + const auto model_input_indices = ExtractOperandIndices(model_input_operands); + const auto model_output_indices = ExtractOperandIndices(model_output_operands); + ICHECK_EQ(ANeuralNetworksModel_identifyInputsAndOutputs( + model_, model_input_indices.size(), model_input_indices.data(), + model_output_indices.size(), model_output_indices.data()), + ANEURALNETWORKS_NO_ERROR); + ICHECK_EQ(ANeuralNetworksModel_finish(model_), ANEURALNETWORKS_NO_ERROR); +} + +ANeuralNetworksCompilation* NNAPIModelBuilder::Compile() { + ANeuralNetworksCompilation* compilation; + ICHECK_EQ(ANeuralNetworksCompilation_create(model_, &compilation), ANEURALNETWORKS_NO_ERROR); + ICHECK_EQ(ANeuralNetworksCompilation_setPreference(compilation, + ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER), + ANEURALNETWORKS_NO_ERROR); + ICHECK_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR); + return compilation; +} + +int32_t TensorTypeFromDLDataType(DLDataType ty) { + if (ty.code == kDLInt) { + if (ty.bits == 32) { + return ANEURALNETWORKS_TENSOR_INT32; + } else { + ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI integer tensor"; + } + } else if (ty.code == kDLUInt) { + if (ty.bits == 1) { + return ANEURALNETWORKS_TENSOR_BOOL8; + } else { + ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI unsigned integer tensor"; + } + } else if (ty.code == kDLFloat) { + if (ty.bits == 32) { + return ANEURALNETWORKS_TENSOR_FLOAT32; + } else if (ty.bits == 16) { + return ANEURALNETWORKS_TENSOR_FLOAT16; + } else { + ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI integer tensor"; + } + } else { + ICHECK(false) << "Unsupported DLDataTypeCode for NNAPI: " << ty.code; + } +} + +std::vector ExtractOperandIndices(const std::vector& operands) { + std::vector indices; + indices.reserve(operands.size()); + std::transform(operands.begin(), operands.end(), std::back_inserter(indices), + [](const NNAPIOperand& operand) { return operand.GetOperandIndex(); }); + return indices; +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif // TVM_GRAPH_EXECUTOR_NNAPI diff --git a/src/runtime/contrib/nnapi/nnapi_builder.h b/src/runtime/contrib/nnapi/nnapi_builder.h new file mode 100644 index 000000000000..4360f50bf1e9 --- /dev/null +++ b/src/runtime/contrib/nnapi/nnapi_builder.h @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_BUILDER_H_ +#define TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_BUILDER_H_ +#ifdef TVM_GRAPH_EXECUTOR_NNAPI + +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +class WrappedANeuralNetworksOperandType { + public: + WrappedANeuralNetworksOperandType(int32_t tensor_type, std::vector dimensions, + float scale, int32_t zero_point); + WrappedANeuralNetworksOperandType(const WrappedANeuralNetworksOperandType&); + WrappedANeuralNetworksOperandType& operator=(const WrappedANeuralNetworksOperandType&); + + const ANeuralNetworksOperandType* Get() const; + + private: + std::vector dimensions_; + ANeuralNetworksOperandType ty_; +}; + +class NNAPIOperand { + public: + NNAPIOperand(uint32_t index, const DLTensor* tensor); + NNAPIOperand(uint32_t index, const int64_t* shape, int ndim, DLDataType dtype); + NNAPIOperand(uint32_t index, int32_t tensor_type, std::vector dimensions, float scale, + int32_t zero_point); + static NNAPIOperand Scalar(uint32_t index, int32_t tensor_type, std::vector dimensions, + float scale, int32_t zero_point); + void SetDimensions(std::vector dimensions); + + WrappedANeuralNetworksOperandType GetOperandType() const; + uint32_t GetOperandIndex() const; + const std::vector& GetDimensions() const; + const float GetScale() const; + const int32_t GetZeroPoint() const; + int32_t GetTensorType() const; + bool IsDynamicShape() const; + + private: + uint32_t index_; + bool scalar_; + + // The NNAPI operand type e.g. ANEURALNETWORKS_TENSOR_INT32. + int32_t tensor_type_; + std::vector dimensions_; + float scale_; + int32_t zero_point_; +}; + +class NNAPIModelBuilder { + public: + NNAPIModelBuilder(); + ~NNAPIModelBuilder(); + NNAPIModelBuilder(const NNAPIModelBuilder&) = delete; + NNAPIModelBuilder& operator=(const NNAPIModelBuilder&) = delete; + inline NNAPIModelBuilder(NNAPIModelBuilder&& other) { + model_ = other.model_; + other.model_ = nullptr; + next_operand_index_ = other.next_operand_index_; + other.next_operand_index_ = 0; + } + inline NNAPIModelBuilder& operator=(NNAPIModelBuilder&& other) { + model_ = other.model_; + other.model_ = nullptr; + next_operand_index_ = other.next_operand_index_; + other.next_operand_index_ = 0; + return *this; + } + + NNAPIOperand CreateOperandWithValue(const DLTensor& tensor); + NNAPIOperand CreateOperandWithValue(int32_t tensor_type, std::vector dimensions, + float scale, int32_t zero_point, const void* buffer, + size_t size); + NNAPIOperand CreateScalarOperandWithValue(OperandCode operand_code, const void* buffer, + size_t size); + + NNAPIOperand CreateOperand(const DLTensor& tensor); + NNAPIOperand CreateOperand(const int64_t* shape, int ndim, DLDataType dtype); + NNAPIOperand CreateOperand(int32_t tensor_type, std::vector dimensions, float scale, + int32_t zero_point); + + void AddOperation(ANeuralNetworksOperationType operation, + const std::vector input_indices, + const std::vector output_indices); + + void Finish(const std::vector& model_input_operands, + const std::vector& model_output_operands); + ANeuralNetworksCompilation* Compile(); + + private: + ANeuralNetworksModel* model_; + uint32_t next_operand_index_ = 0; +}; + +/*! + * \brief Convert a DLDataType to an NNAPI OperandCode. + */ +int32_t TensorTypeFromDLDataType(DLDataType ty); + +std::vector ExtractOperandIndices(const std::vector& operands); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_GRAPH_EXECUTOR_NNAPI +#endif // TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_BUILDER_H_ diff --git a/src/runtime/contrib/nnapi/nnapi_ops.cc b/src/runtime/contrib/nnapi/nnapi_ops.cc new file mode 100644 index 000000000000..ad055ec2c76f --- /dev/null +++ b/src/runtime/contrib/nnapi/nnapi_ops.cc @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef TVM_GRAPH_EXECUTOR_NNAPI +#include "nnapi_ops.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nnapi_builder.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +NNAPIOpConverterParams::NNAPIOpConverterParams(const JSONGraphNode& node) : node(node) {} + +NNAPIOpConverter::NNAPIOpConverter(std::string op_name) : op_name_(op_name) {} + +void ElwBinaryOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + // A map from op names to NNAPI OperationCode and whether it requires a FuseCode. + static const std::unordered_map> + op_map = { + {"add", {ANEURALNETWORKS_ADD, true}}, + {"mul", {ANEURALNETWORKS_MUL, true}}, + {"div", {ANEURALNETWORKS_DIV, true}}, + {"sub", {ANEURALNETWORKS_SUB, true}}, + {"pow", {ANEURALNETWORKS_POW, false}}, + {"equal", {ANEURALNETWORKS_EQUAL, false}}, + {"greater", {ANEURALNETWORKS_GREATER, false}}, + {"greater_equal", {ANEURALNETWORKS_GREATER_EQUAL, false}}, + {"less", {ANEURALNETWORKS_LESS, false}}, + {"less_equal", {ANEURALNETWORKS_LESS_EQUAL, false}}, + {"not_equal", {ANEURALNETWORKS_NOT_EQUAL, false}}, + {"maximum", {ANEURALNETWORKS_MAXIMUM, false}}, + {"minimum", {ANEURALNETWORKS_MINIMUM, false}}, + }; + + auto it = op_map.find(op_name_); + ICHECK(it != op_map.end()) << "Unsupported binary operation type " << op_name_; + const ANeuralNetworksOperationType operation_type = std::get<0>(it->second); + const bool requires_fuse_code = std::get<1>(it->second); + + ICHECK_EQ(inputs.size(), 2) << "Expected binary operation to have 2 inputs but got " + << inputs.size(); + + auto input_indices = ExtractOperandIndices(inputs); + const auto output_indices = ExtractOperandIndices(outputs); + + if (requires_fuse_code) { + // Create an extra input at index 2 for the fuse code. + const int32_t fused_none = ANEURALNETWORKS_FUSED_NONE; + const NNAPIOperand fuse_code_operand = builder.CreateScalarOperandWithValue( + ANEURALNETWORKS_INT32, &fused_none, sizeof(fused_none)); + input_indices.push_back(fuse_code_operand.GetOperandIndex()); + } + + builder.AddOperation(operation_type, input_indices, output_indices); +} + +void UnaryOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + static const std::unordered_map op_map = { + // clang-format off + {"floor", ANEURALNETWORKS_FLOOR}, + {"logistic", ANEURALNETWORKS_LOGISTIC}, + {"relu", ANEURALNETWORKS_RELU}, + {"tanh", ANEURALNETWORKS_TANH}, + {"abs", ANEURALNETWORKS_ABS}, + {"exp", ANEURALNETWORKS_EXP}, + {"log", ANEURALNETWORKS_LOG}, + {"neg", ANEURALNETWORKS_NEG}, + {"sqrt", ANEURALNETWORKS_SQRT}, + {"rsqrt", ANEURALNETWORKS_RSQRT}, + // clang-format on + }; + auto it = op_map.find(op_name_); + ICHECK(it != op_map.end()) << "Unsupported unary operation type " << op_name_; + const ANeuralNetworksOperationType operation_type = it->second; + + const auto input_indices = ExtractOperandIndices(inputs); + const auto output_indices = ExtractOperandIndices(outputs); + builder.AddOperation(operation_type, input_indices, output_indices); +} + +void SoftmaxOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + ICHECK_EQ(inputs.size(), 1) << "Unsupported number of inputs for NNAPI softmax operation: " + << inputs.size(); + + auto input_indices = ExtractOperandIndices(inputs); + const auto output_indices = ExtractOperandIndices(outputs); + + // Add the scalar input for beta value at index 1. + const auto& input = inputs[0]; + // TODO(PLLab): Conditionally use float16 beta for float16 input. + ICHECK_EQ(input.GetTensorType(), ANEURALNETWORKS_TENSOR_FLOAT32) + << "NNAPI runtime does not support non-float32 inputs for softmax yet"; + const float beta = 1.0f; + const NNAPIOperand beta_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_FLOAT32, &beta, sizeof beta); + input_indices.push_back(beta_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices, output_indices); +} + +// Insert a reshape operation that reshapes `operand` to `dimensions` and return the reshaped +// operand. +NNAPIOperand ReshapeOperand(NNAPIModelBuilder& builder, const NNAPIOperand& operand, // NOLINT(*) + std::vector dimensions) { + // ANEURALNETWORKS_RESHAPE requires the dimensions to be specified in a int32 tensor. + const std::vector dimensions_int32(dimensions.begin(), dimensions.end()); + const std::vector dim_of_dims{static_cast(dimensions_int32.size())}; + + const NNAPIOperand reshape_shape_operand = + builder.CreateOperandWithValue(ANEURALNETWORKS_TENSOR_INT32, dim_of_dims, 0.0f, 0, + reinterpret_cast(dimensions_int32.data()), + dimensions_int32.size() * sizeof(*dimensions_int32.data())); + const NNAPIOperand reshaped_operand = builder.CreateOperand( + operand.GetTensorType(), dimensions, operand.GetScale(), operand.GetZeroPoint()); + + builder.AddOperation( + ANEURALNETWORKS_RESHAPE, + std::vector{operand.GetOperandIndex(), reshape_shape_operand.GetOperandIndex()}, + std::vector{reshaped_operand.GetOperandIndex()}); + return reshaped_operand; +} + +NNAPIOperand TransposeOperand(NNAPIModelBuilder& builder, const NNAPIOperand& operand, // NOLINT(*) + std::vector dimensions) { + const std::vector dimensions_int32(dimensions.begin(), dimensions.end()); + const std::vector dim_of_axes{static_cast(dimensions_int32.size())}; + std::vector result_dimension; + for (size_t i = 0; i < dimensions.size(); i++) { + result_dimension.push_back(operand.GetDimensions()[dimensions_int32[i]]); + } + + const NNAPIOperand transpose_shape_operand = + builder.CreateOperandWithValue(ANEURALNETWORKS_TENSOR_INT32, dim_of_axes, 0.0f, 0, + reinterpret_cast(dimensions_int32.data()), + dimensions_int32.size() * sizeof(*dimensions_int32.data())); + const NNAPIOperand transposed_operand = builder.CreateOperand( + operand.GetTensorType(), result_dimension, operand.GetScale(), operand.GetZeroPoint()); + + builder.AddOperation( + ANEURALNETWORKS_TRANSPOSE, + std::vector{operand.GetOperandIndex(), transpose_shape_operand.GetOperandIndex()}, + std::vector{transposed_operand.GetOperandIndex()}); + + return transposed_operand; +} + +void MatmulOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + ICHECK_EQ(inputs.size(), 2); + + auto input_indices = ExtractOperandIndices(inputs); + const auto output_indices = ExtractOperandIndices(outputs); + + const size_t input0_ndim = inputs[0].GetDimensions().size(); + const size_t input1_ndim = inputs[1].GetDimensions().size(); + if (input0_ndim != input1_ndim) { + if (input0_ndim > input1_ndim) { + // Check that the extra leading dimensions on input 0 are all ones. + const size_t diff = input0_ndim - input1_ndim; + for (size_t i = 0; i < diff; ++i) { + ICHECK_EQ(inputs[0].GetDimensions()[i], 1); + } + + // Expand input 1's dimensions. + std::vector reshaped_dimensions(diff, 1); + reshaped_dimensions.insert(reshaped_dimensions.end(), inputs[1].GetDimensions().begin(), + inputs[1].GetDimensions().end()); + const auto reshaped_operand = ReshapeOperand(builder, inputs[1], reshaped_dimensions); + input_indices[1] = reshaped_operand.GetOperandIndex(); + } else { + // input0_ndim < input1_ndim + // Check that the extra leading dimensions on input 1 are all ones. + const size_t diff = input1_ndim - input0_ndim; + for (size_t i = 0; i < diff; ++i) { + ICHECK_EQ(inputs[1].GetDimensions()[i], 1); + } + + // Expand input 0's dimensions. + std::vector reshaped_dimensions(diff, 1); + reshaped_dimensions.insert(reshaped_dimensions.end(), inputs[0].GetDimensions().begin(), + inputs[0].GetDimensions().end()); + const auto reshaped_operand = ReshapeOperand(builder, inputs[0], reshaped_dimensions); + input_indices[0] = reshaped_operand.GetOperandIndex(); + } + } + + { + const unsigned char adj_x = 0; + const NNAPIOperand adj_x_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_BOOL, &adj_x, sizeof(adj_x)); + input_indices.push_back(adj_x_operand.GetOperandIndex()); + } + + { + const unsigned char adj_y = 0; + const NNAPIOperand adj_y_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_BOOL, &adj_y, sizeof(adj_y)); + input_indices.push_back(adj_y_operand.GetOperandIndex()); + } + + builder.AddOperation(ANEURALNETWORKS_BATCH_MATMUL, input_indices, output_indices); +} + +void TransposeOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + ICHECK_EQ(inputs.size(), 1); + + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + std::vector axes; + if (node.HasAttr("axes")) { + const auto axes_attr = node.GetAttr>("axes"); + for (auto str_axis : axes_attr) { + axes.push_back(std::stoi(str_axis)); + } + } else { + for (size_t i = 0; i < inputs[0].GetDimensions().size(); ++i) { + axes.push_back(i); + } + std::reverse(axes.begin(), axes.end()); + } + + const std::vector dim_of_axes{static_cast(axes.size())}; + const NNAPIOperand perm_operand = builder.CreateOperandWithValue( + ANEURALNETWORKS_TENSOR_INT32, dim_of_axes, 0.0f, 0, + reinterpret_cast(axes.data()), axes.size() * sizeof(*axes.data())); + input_indices.push_back(perm_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_TRANSPOSE, input_indices, output_indices); +} + +void CastOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + // Extract the dtype attribute and check that the output operand type matches the dtype specified. + const auto dtype_attr = node.GetAttr>("astype_dtype"); + ICHECK(dtype_attr.size() == 1); + const auto dtype_str = dtype_attr[0]; + const DLDataType dtype = runtime::String2DLDataType(dtype_str); + ICHECK(outputs.size() == 1); + const auto output_tensor_type = outputs[0].GetTensorType(); + ICHECK(TensorTypeFromDLDataType(dtype) == output_tensor_type) + << "Expect a cast to dtype " << dtype_str << " but got output operand of type " + << output_tensor_type; + + builder.AddOperation(ANEURALNETWORKS_CAST, input_indices, output_indices); +} + +template +NNAPIOperand CreateConv2DBiasOperand(NNAPIModelBuilder& builder, // NOLINT(*) + int64_t output_depth) { + std::vector bias(output_depth, 0.0f); + + const std::vector dim_of_bias{static_cast(bias.size())}; + const NNAPIOperand bias_operand = builder.CreateOperandWithValue( + TensorType, dim_of_bias, 0.0f, 0, reinterpret_cast(bias.data()), + bias.size() * sizeof(*bias.data())); + return bias_operand; +} + +void Conv2dOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + ICHECK(inputs.size() >= 2); + const auto input_tensor_type = inputs[0].GetTensorType(); + const auto filter_tensor_type = inputs[1].GetTensorType(); + ICHECK(input_tensor_type == filter_tensor_type); + ICHECK(input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + ICHECK(filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + + // transpose kernel + std::vector transposed_dimensions{0, 2, 3, 1}; + const auto transposed_operand = TransposeOperand(builder, inputs[1], transposed_dimensions); + + input_indices[1] = transposed_operand.GetOperandIndex(); + + // bias operand + if (input_indices.size() == 2) { + const int output_depth = inputs[1].GetDimensions()[0]; + if (input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32) { + const NNAPIOperand bias_operand = + CreateConv2DBiasOperand(builder, output_depth); + input_indices.push_back(bias_operand.GetOperandIndex()); + } else if (input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16) { + const NNAPIOperand bias_operand = + CreateConv2DBiasOperand(builder, output_depth); + input_indices.push_back(bias_operand.GetOperandIndex()); + } + } else { + int64_t bias_dim; + for (int i = 0; i < inputs[2].GetDimensions().size(); i++) { + if (inputs[2].GetDimensions()[i] != 1) { + bias_dim = inputs[2].GetDimensions()[i]; + } + } + std::vector bias_dimension = {bias_dim}; + NNAPIOperand bias_operand = ReshapeOperand(builder, inputs[2], bias_dimension); + input_indices[2] = bias_operand.GetOperandIndex(); + } + // padding operand + std::vector padding; + const auto padding_attr = node.GetAttr>("padding"); + + for (auto str_pad : padding_attr) { + padding.push_back(std::stoi(str_pad)); + } + + ICHECK(padding.size() == 4) << "NNAPI runtime currently only supports 4-way padding for Conv2D"; + const NNAPIOperand padding_left_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[1], sizeof(padding[1])); + input_indices.push_back(padding_left_operand.GetOperandIndex()); + + const NNAPIOperand padding_right_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[3], sizeof(padding[3])); + input_indices.push_back(padding_right_operand.GetOperandIndex()); + + const NNAPIOperand padding_top_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[0], sizeof(padding[0])); + input_indices.push_back(padding_top_operand.GetOperandIndex()); + + const NNAPIOperand padding_bottom_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[2], sizeof(padding[2])); + input_indices.push_back(padding_bottom_operand.GetOperandIndex()); + + // stride operand + std::vector stride; + const auto stride_attr = node.GetAttr>("strides"); + for (auto str_stride : stride_attr) { + stride.push_back(std::stoi(str_stride)); + } + + ICHECK(stride.size() == 2); + const NNAPIOperand stride_width_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[0], sizeof(stride[0])); + input_indices.push_back(stride_width_operand.GetOperandIndex()); + + const NNAPIOperand stride_height_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[1], sizeof(stride[1])); + input_indices.push_back(stride_height_operand.GetOperandIndex()); + + // group + int32_t group; + const auto group_attr = node.GetAttr>("group"); + for (auto str_group : group_attr) { + group = std::stoi(str_group); + } + + if (group > 1) { + const NNAPIOperand group_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &group, sizeof(group)); + input_indices.push_back(group_operand.GetOperandIndex()); + } + + // fuse code + const int32_t fused_none = ANEURALNETWORKS_FUSED_NONE; + const NNAPIOperand fuse_code_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &fused_none, sizeof(fused_none)); + input_indices.push_back(fuse_code_operand.GetOperandIndex()); + + // layout + // Use NCHW layout for input 0 and output 0. + const bool layout = true; + const NNAPIOperand layout_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_BOOL, &layout, sizeof(layout)); + input_indices.push_back(layout_operand.GetOperandIndex()); + + if (group > 1) { + builder.AddOperation(ANEURALNETWORKS_GROUPED_CONV_2D, input_indices, output_indices); + } else { + builder.AddOperation(ANEURALNETWORKS_CONV_2D, input_indices, output_indices); + } +} + +void MaxPool2dOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + // padding operand + std::vector padding; + const auto padding_attr = node.GetAttr>("padding"); + + for (auto str_pad : padding_attr) { + padding.push_back(std::stoi(str_pad)); + } + + const NNAPIOperand padding_left_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[1], sizeof(padding[1])); + input_indices.push_back(padding_left_operand.GetOperandIndex()); + + const NNAPIOperand padding_right_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[3], sizeof(padding[3])); + input_indices.push_back(padding_right_operand.GetOperandIndex()); + + const NNAPIOperand padding_top_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[0], sizeof(padding[0])); + input_indices.push_back(padding_top_operand.GetOperandIndex()); + + const NNAPIOperand padding_bottom_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[2], sizeof(padding[2])); + input_indices.push_back(padding_bottom_operand.GetOperandIndex()); + + // stride operand + std::vector stride; + const auto stride_attr = node.GetAttr>("strides"); + for (auto str_stride : stride_attr) { + stride.push_back(std::stoi(str_stride)); + } + + const NNAPIOperand stride_width_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[0], sizeof(stride[0])); + input_indices.push_back(stride_width_operand.GetOperandIndex()); + + const NNAPIOperand stride_height_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[1], sizeof(stride[1])); + input_indices.push_back(stride_height_operand.GetOperandIndex()); + + // filter operand + std::vector pool_size; + const auto pool_size_attr = node.GetAttr>("pool_size"); + for (auto size : pool_size_attr) { + pool_size.push_back(std::stoi(size)); + } + + const NNAPIOperand pool_size_width_operand = builder.CreateScalarOperandWithValue( + ANEURALNETWORKS_INT32, &pool_size[0], sizeof(pool_size[0])); + input_indices.push_back(pool_size_width_operand.GetOperandIndex()); + + const NNAPIOperand pool_size_height_operand = builder.CreateScalarOperandWithValue( + ANEURALNETWORKS_INT32, &pool_size[1], sizeof(pool_size[1])); + input_indices.push_back(pool_size_height_operand.GetOperandIndex()); + + // fuse code + const int32_t fused_none = ANEURALNETWORKS_FUSED_NONE; + const NNAPIOperand fuse_code_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &fused_none, sizeof(fused_none)); + input_indices.push_back(fuse_code_operand.GetOperandIndex()); + + // layout + const bool layout = true; + const NNAPIOperand layout_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_BOOL, &layout, sizeof(layout)); + input_indices.push_back(layout_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_MAX_POOL_2D, input_indices, output_indices); +} + +void DenseOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + const auto input_tensor_type = inputs[0].GetTensorType(); + const auto filter_tensor_type = inputs[1].GetTensorType(); + ICHECK(input_tensor_type == filter_tensor_type); + ICHECK(input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + ICHECK(filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + + if (input_indices.size() == 2) { + const int output_depth = inputs[1].GetDimensions()[0]; + if (input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32) { + const NNAPIOperand bias_operand = + CreateConv2DBiasOperand(builder, output_depth); + input_indices.push_back(bias_operand.GetOperandIndex()); + } else if (input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16) { + const NNAPIOperand bias_operand = + CreateConv2DBiasOperand(builder, output_depth); + input_indices.push_back(bias_operand.GetOperandIndex()); + } + } + + // fuse code + const int32_t fused_none = ANEURALNETWORKS_FUSED_NONE; + const NNAPIOperand fuse_code_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &fused_none, sizeof(fused_none)); + input_indices.push_back(fuse_code_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_FULLY_CONNECTED, input_indices, output_indices); +} + +void MeanOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const { + auto input_indices = ExtractOperandIndices(inputs); + auto output_indices = ExtractOperandIndices(outputs); + + // Extract the axis attribute and create an operand for it. + const auto axis_attr = node.GetAttr>("axis"); + std::vector axis; + for (auto dim : axis_attr) { + axis.push_back(std::stoi(dim)); + } + const std::vector dim_of_axis{static_cast(axis.size())}; + + const NNAPIOperand axis_operand = builder.CreateOperandWithValue( + ANEURALNETWORKS_TENSOR_INT32, dim_of_axis, 0.0f, 0, + reinterpret_cast(axis.data()), axis.size() * sizeof(*axis.data())); + input_indices.push_back(axis_operand.GetOperandIndex()); + + // Extract the keepdims attribute and create an operand for it. + const auto keepdims_attr = node.GetAttr>("keepdims"); + ICHECK(keepdims_attr.size() == 1); + const int32_t keepdims = keepdims_attr[0] == "1"; + + const NNAPIOperand keepdims_operand = + builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &keepdims, sizeof keepdims); + input_indices.push_back(keepdims_operand.GetOperandIndex()); + + builder.AddOperation(ANEURALNETWORKS_MEAN, input_indices, output_indices); +} + +const std::unordered_map>& GetOpConverters() { + static const std::unordered_map> map = []() { + std::unordered_map> map; + map.emplace("nnapi.add", std::make_unique("add")); + map.emplace("nnapi.mul", std::make_unique("mul")); + map.emplace("nnapi.div", std::make_unique("div")); + map.emplace("nnapi.sub", std::make_unique("sub")); + map.emplace("nnapi.pow", std::make_unique("pow")); + map.emplace("nnapi.equal", std::make_unique("equal")); + map.emplace("nnapi.greater", std::make_unique("greater")); + map.emplace("nnapi.greater_equal", std::make_unique("greater_equal")); + map.emplace("nnapi.less", std::make_unique("less")); + map.emplace("nnapi.less_equal", std::make_unique("less_equal")); + map.emplace("nnapi.not_equal", std::make_unique("not_equal")); + map.emplace("nnapi.maximum", std::make_unique("maximum")); + map.emplace("nnapi.minimum", std::make_unique("minimum")); + map.emplace("nnapi.floor", std::make_unique("floor")); + map.emplace("nnapi.logistic", std::make_unique("logistic")); + map.emplace("nnapi.relu", std::make_unique("relu")); + map.emplace("nnapi.tanh", std::make_unique("tanh")); + map.emplace("nnapi.abs", std::make_unique("abs")); + map.emplace("nnapi.exp", std::make_unique("exp")); + map.emplace("nnapi.log", std::make_unique("log")); + map.emplace("nnapi.neg", std::make_unique("neg")); + map.emplace("nnapi.sqrt", std::make_unique("sqrt")); + map.emplace("nnapi.rsqrt", std::make_unique("rsqrt")); + map.emplace("nnapi.softmax", std::make_unique()); + map.emplace("nnapi.batch_matmul", std::make_unique()); + map.emplace("nnapi.transpose", std::make_unique()); + map.emplace("nnapi.cast", std::make_unique("cast")); + map.emplace("nnapi.mean", std::make_unique("mean")); + map.emplace("nnapi.conv2d", std::make_unique()); + map.emplace("nnapi.fully_connected", std::make_unique()); + map.emplace("nnapi.max_pool_2d", std::make_unique()); + return map; + }(); + return map; +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif // TVM_GRAPH_EXECUTOR_NNAPI diff --git a/src/runtime/contrib/nnapi/nnapi_ops.h b/src/runtime/contrib/nnapi/nnapi_ops.h new file mode 100644 index 000000000000..748a0b1d526c --- /dev/null +++ b/src/runtime/contrib/nnapi/nnapi_ops.h @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_OPS_H_ +#define TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_OPS_H_ +#ifdef TVM_GRAPH_EXECUTOR_NNAPI + +#include + +#include +#include +#include +#include + +#include "../json/json_node.h" +#include "nnapi_builder.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +struct NNAPIOpConverterParams { + const JSONGraphNode& node; + std::vector inputs; + std::vector outputs; + explicit NNAPIOpConverterParams(const JSONGraphNode& node); +}; + +class NNAPIOpConverter { + public: + std::string op_name_; + + explicit NNAPIOpConverter(std::string op_name); + virtual ~NNAPIOpConverter() = default; + + virtual void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, // NOLINT(*) + const std::vector& inputs, + std::vector& outputs) const = 0; // NOLINT(*) +}; + +class ElwBinaryOpConverter : public NNAPIOpConverter { + public: + inline explicit ElwBinaryOpConverter(std::string op_name) : NNAPIOpConverter(op_name) {} + ~ElwBinaryOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +class UnaryOpConverter : public NNAPIOpConverter { + public: + inline explicit UnaryOpConverter(std::string op_name) : NNAPIOpConverter(op_name) {} + ~UnaryOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +class SoftmaxOpConverter : public NNAPIOpConverter { + public: + inline SoftmaxOpConverter() : NNAPIOpConverter("softmax") {} + ~SoftmaxOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +class MatmulOpConverter : public NNAPIOpConverter { + public: + inline MatmulOpConverter() : NNAPIOpConverter("") {} + ~MatmulOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +class TransposeOpConverter : public NNAPIOpConverter { + public: + inline TransposeOpConverter() : NNAPIOpConverter("") {} + ~TransposeOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +class CastOpConverter : public NNAPIOpConverter { + public: + inline explicit CastOpConverter(std::string op_name) : NNAPIOpConverter(op_name) {} + ~CastOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; +class Conv2dOpConverter : public NNAPIOpConverter { + public: + inline Conv2dOpConverter() : NNAPIOpConverter("") {} + ~Conv2dOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +class DenseOpConverter : public NNAPIOpConverter { + public: + inline DenseOpConverter() : NNAPIOpConverter("") {} + ~DenseOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +class MaxPool2dOpConverter : public NNAPIOpConverter { + public: + inline MaxPool2dOpConverter() : NNAPIOpConverter("") {} + ~MaxPool2dOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +class MeanOpConverter : public NNAPIOpConverter { + public: + inline explicit MeanOpConverter(std::string op_name) : NNAPIOpConverter(op_name) {} + ~MeanOpConverter() = default; + + void Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, + const std::vector& inputs, + std::vector& outputs) const override; +}; + +const std::unordered_map>& GetOpConverters(); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_GRAPH_EXECUTOR_NNAPI +#endif // TVM_RUNTIME_CONTRIB_NNAPI_NNAPI_OPS_H_ diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc new file mode 100644 index 000000000000..c63098873da1 --- /dev/null +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../json/json_node.h" +#include "../json/json_runtime.h" + +#ifdef TVM_GRAPH_EXECUTOR_NNAPI +#include +#include + +#include "nnapi_builder.h" +#include "nnapi_ops.h" +#endif + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +class NNAPIRuntime : public JSONRuntimeBase { + public: + explicit NNAPIRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array& const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + + const char* type_key() const final { return "nnapi"; } + +#ifdef TVM_GRAPH_EXECUTOR_NNAPI + struct CompiledModel { + CompiledModel(NNAPIModelBuilder builder, ANeuralNetworksCompilation* compilation, + std::vector model_output_operands) + : builder(std::move(builder)), + compilation(compilation), + model_output_operands(model_output_operands) {} + NNAPIModelBuilder builder; + ANeuralNetworksCompilation* compilation; + std::vector model_output_operands; + }; + + std::optional compiled_model_; + + void Init(const Array& consts) final { + ICHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required constants."; + SetupConstants(consts); + CompileModel(); + } + + void CompileModel() { + NNAPIModelBuilder builder; + + // Clear the map, otherwise the input shapes from last inference gets used. + node_output_map_.clear(); + + // Add inputs as NNAPI model operands. + std::vector model_input_operands; + for (size_t i = 0; i < input_nodes_.size(); ++i) { + const uint32_t nid = input_nodes_[i]; + if (nodes_[nid].GetOpType() == "input") { + for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { + const std::vector input_shape = nodes_[nid].GetOpShape()[j]; + const auto input_dtype = nodes_[nid].GetOpDataType()[j]; + const NNAPIOperand operand = + builder.CreateOperand(input_shape.data(), input_shape.size(), input_dtype); + node_output_map_.emplace(nid, operand); + model_input_operands.push_back(operand); + } + } + } + + // Add kernels as NNAPI operations. + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const auto& node = nodes_[nid]; + if (node.GetOpType() != "kernel") { + continue; + } + AddOperation(builder, nid, node); + } + + // Collect the output operands indices. + std::vector model_output_operands; + for (size_t i = 0; i < outputs_.size(); ++i) { + const auto& node = outputs_[i]; + auto it = node_output_map_.find(node.id_); + ICHECK(it != node_output_map_.end()) << "Missing model output."; + const auto& operand = it->second; + model_output_operands.push_back(operand); + } + + // Finish and compile the model. + builder.Finish(model_input_operands, model_output_operands); + ANeuralNetworksCompilation* compilation = builder.Compile(); + + // Store the compilation + compiled_model_.emplace(std::move(builder), compilation, model_output_operands); + } + + void ExecuteModel(ANeuralNetworksCompilation* compilation, + const std::vector& model_output_operands) { + // Execute the model. + ANeuralNetworksExecution* execution; + ICHECK_EQ(ANeuralNetworksExecution_create(compilation, &execution), ANEURALNETWORKS_NO_ERROR); + + for (size_t i = 0; i < input_nodes_.size(); ++i) { + const uint32_t nid = input_nodes_[i]; + if (nodes_[nid].GetOpType() == "input") { + for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { + auto it = node_output_map_.find(nid); + ICHECK(it != node_output_map_.end()) << "Missing model input."; + const auto& operand = it->second; + + const uint32_t eid = EntryID(nid, j); + const auto entry = data_entry_[eid]; + + const auto operand_data_size = GetDataSize(*entry); + ICHECK_EQ(ANeuralNetworksExecution_setInput(execution, i, operand.GetOperandType().Get(), + entry->data, operand_data_size), + ANEURALNETWORKS_NO_ERROR); + } + } + } + + for (size_t i = 0; i < outputs_.size(); ++i) { + const auto& operand = model_output_operands[i]; + const auto& node = outputs_[i]; + + const auto eid = EntryID(node); + const auto entry = data_entry_[eid]; + + const auto operand_data_size = GetDataSize(*entry); + ICHECK_EQ(ANeuralNetworksExecution_setOutput(execution, i, operand.GetOperandType().Get(), + entry->data, operand_data_size), + ANEURALNETWORKS_NO_ERROR); + } + + ANeuralNetworksEvent* compute_event; + ICHECK_EQ(ANeuralNetworksExecution_startCompute(execution, &compute_event), + ANEURALNETWORKS_NO_ERROR); + ICHECK_EQ(ANeuralNetworksEvent_wait(compute_event), ANEURALNETWORKS_NO_ERROR); + ANeuralNetworksEvent_free(compute_event); + + ANeuralNetworksExecution_free(execution); + } + + void Run() final { + ICHECK(compiled_model_.has_value()); + CompiledModel& compiled_model = compiled_model_.value(); + ExecuteModel(compiled_model.compilation, compiled_model.model_output_operands); + } + + void AddOperation(NNAPIModelBuilder& builder, uint32_t nid, // NOLINT(*) + const JSONGraphNode& node) { + std::vector inputs; + std::vector outputs; + + // Map the op name to its converter. + const auto& converter_map = GetOpConverters(); + auto it = converter_map.find(node.GetOpName()); + ICHECK(it != converter_map.end()) << node.GetOpName() << ": Unsupported operation name"; + const NNAPIOpConverter& converter = *it->second; + + // Add input operands to params. + for (size_t i = 0; i < node.GetInputs().size(); ++i) { + auto in_node = node.GetInputs()[i]; + auto it = node_output_map_.find(in_node.id_); + ICHECK(it != node_output_map_.end()) << node.GetOpName() << ": Missing input"; + auto& operand = it->second; + inputs.push_back(operand); + } + + // Create and add output operands to params. + const auto output_shapes = node.GetOpShape(); + const auto output_dtypes = node.GetOpDataType(); + ICHECK(output_shapes.size() == output_dtypes.size()) + << "The number of output shapes must match the number of output dtypes"; + ICHECK(output_shapes.size() == 1) + << "NNAPI runtime currently does not support more than one output per operation yet"; + + for (size_t i = 0; i < output_shapes.size(); ++i) { + auto output_shape = output_shapes[i]; + const NNAPIOperand output_operand = + builder.CreateOperand(output_shape.data(), output_shape.size(), output_dtypes[i]); + outputs.push_back(output_operand); + } + + converter.Convert(builder, node, inputs, outputs); + + // Record the final output shape. + node_output_map_.emplace(nid, outputs[0]); + } + + private: + // Mapping from JSON node IDs to NNAPI operand numbers. + std::unordered_map node_output_map_; + +#else // ifdef TVM_GRAPH_EXECUTOR_NNAPI + void Init(const Array& consts) final { + LOG(FATAL) << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it."; + } + + void Run() final { + LOG(FATAL) << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it."; + } +#endif // ifdef TVM_GRAPH_EXECUTOR_NNAPI +}; + +runtime::Module NNAPIRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { + auto n = make_object(symbol_name, graph_json, const_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.nnapi_runtime_create").set_body_typed(NNAPIRuntimeCreate); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_nnapi") + .set_body_typed(JSONRuntimeBase::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 73800338b143..2d1c33cbf282 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -279,6 +279,14 @@ #define TVM_INFO_USE_NVSHMEM "NOT-FOUND" #endif +#ifndef TVM_INFO_USE_NNAPI_CODEGEN +#define TVM_INFO_USE_NNAPI_CODEGEN "NOT-FOUND" +#endif + +#ifndef TVM_INFO_USE_NNAPI_RUNTIME +#define TVM_INFO_USE_NNAPI_RUNTIME "NOT-FOUND" +#endif + namespace tvm { /*! @@ -392,6 +400,8 @@ TVM_DLL Map GetLibInfo() { {"USE_MSC", TVM_INFO_USE_MSC}, {"USE_CCACHE", TVM_INFO_USE_CCACHE}, {"USE_NVSHMEM", TVM_INFO_USE_NVSHMEM}, + {"USE_NNAPI_CODEGEN", TVM_INFO_USE_NNAPI_CODEGEN}, + {"USE_NNAPI_RUNTIME", TVM_INFO_USE_NNAPI_RUNTIME}, {"BACKTRACE_ON_SEGFAULT", TVM_INFO_BACKTRACE_ON_SEGFAULT}, }; return result; diff --git a/tests/python/nightly/test_nnapi/__init__.py b/tests/python/nightly/test_nnapi/__init__.py new file mode 100644 index 000000000000..b2606427b1d8 --- /dev/null +++ b/tests/python/nightly/test_nnapi/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Infrastructure and tests for NNAPI""" diff --git a/tests/python/nightly/test_nnapi/conftest.py b/tests/python/nightly/test_nnapi/conftest.py new file mode 100644 index 000000000000..abed80995a59 --- /dev/null +++ b/tests/python/nightly/test_nnapi/conftest.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +import pytest + +from tvm import rpc + + +def remote(): + if ( + "TVM_TRACKER_HOST" in os.environ + and "TVM_TRACKER_PORT" in os.environ + and "RPC_DEVICE_KEY" in os.environ + ): + + rpc_tracker_host = os.environ["TVM_TRACKER_HOST"] + rpc_tracker_port = int(os.environ["TVM_TRACKER_PORT"]) + rpc_device_key = os.environ["RPC_DEVICE_KEY"] + tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port) + remote = tracker.request(rpc_device_key, priority=0, session_timeout=600) + return remote, tracker + else: + return None diff --git a/tests/python/nightly/test_nnapi/infrastructure.py b/tests/python/nightly/test_nnapi/infrastructure.py new file mode 100644 index 000000000000..aa5580c375ae --- /dev/null +++ b/tests/python/nightly/test_nnapi/infrastructure.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np + +import tvm +import tvm.script.relax as R + +# from tvm.contrib.debugger import debug_runtime as graph_executor +from tvm.contrib import ndk, utils +from tvm.relax.backend.contrib.nnapi import partition_for_nnapi + + +# pylint: disable=import-outside-toplevel,missing-function-docstring +def reshape_matmul(mod: tvm.IRModule): + from typing import Dict + + from tvm.relax import Expr + from tvm.relax.dpl import DFPattern, rewrite_call + from tvm.relax.dpl.pattern import is_op, wildcard + + input0 = wildcard() + input1 = wildcard() + pattern = is_op("relax.matmul")(input0, input1) + + def _rewriter(expr: Expr, matches: Dict[DFPattern, Expr]): + i0 = matches[input0] + i1 = matches[input1] + if len(i0.struct_info.shape) == 2 and len(i1.struct_info.shape) == 2: + i0_shape = [1] + [*i0.struct_info.shape.values] + i1_shape = [1] + [*i1.struct_info.shape.values] + oshape = matches[pattern].struct_info.shape + return R.reshape(R.matmul(R.reshape(i0, i0_shape), R.reshape(i1, i1_shape)), oshape) + return expr + + mod["main"] = rewrite_call(pattern, _rewriter, mod["main"]) + return mod + + +def decompose_clip(mod: tvm.IRModule) -> tvm.IRModule: + from typing import Dict + + from tvm.relax import Expr + from tvm.relax.dpl import DFPattern, rewrite_call + from tvm.relax.dpl.pattern import is_op, wildcard + + input_pattern = wildcard() + min_pattern = wildcard() + max_pattern = wildcard() + pattern = is_op("relax.clip")(input_pattern, min_pattern, max_pattern) + + def _rewriter( + expr: Expr, matches: Dict[DFPattern, Expr] + ) -> Expr: # pylint: disable=unused-argument + dtype = matches[input_pattern].struct_info.dtype + return R.minimum( + R.maximum( + matches[input_pattern], + R.const(np.array(matches[min_pattern].value.value).astype(dtype), dtype), + ), + R.const(np.array(matches[max_pattern].value.value).astype(dtype), dtype), + ) + + mod["main"] = rewrite_call(pattern, _rewriter, mod["main"]) + return mod + + +def _build(mod, enable_nnapi): + if isinstance(mod, tvm.relay.expr.Call): + mod = tvm.IRModule.from_expr(mod) + + if enable_nnapi: + mod = tvm.relax.transform.FoldConstant()(mod) + mod = reshape_matmul(mod) + mod = decompose_clip(mod) + mod = partition_for_nnapi(mod) + + mod = tvm.relax.transform.RunCodegen()(mod) + ex = tvm.relax.build(mod, target="llvm -mtriple=aarch64-linux-android") + + return ex + + +def _run(remote, tracker, ex, inputs): + + tmp = utils.tempdir() + so_name = "test_mod.so" + so_path = tmp / so_name + ex.export_library(str(so_path), fcompile=ndk.create_shared, options=["-shared", "-fPIC", "-lm"]) + + remote.upload(so_path) + dev = remote.cpu(0) + + try: + + # Execute the model on the remote. + remote_ex = remote.load_module(so_name) + vm = tvm.relax.VirtualMachine(remote_ex, device=dev) + + inputs = [x.copyto(dev) for x in inputs] + + vm.set_input("main", *inputs) + vm.invoke_stateful("main") + output = vm.get_outputs("main") + output = output.numpy() + except Exception as e: + # Re-raise all exceptions + raise e + finally: + # Manually close the connection. + # See https://discuss.tvm.apache.org/t/trouble-with-rpc-session/14008/. + # + # TODO: Remove if it does not happen on Python 3.11. + remote._sess.get_function("CloseRPCConnection")() + tracker.close() + pass + + return output + + +def build_and_run( + remote, + tracker, + mod, + inputs, + enable_nnapi=False, +): + ex = _build(mod, enable_nnapi) + return _run(remote, tracker, ex, inputs) diff --git a/tests/python/nightly/test_nnapi/test_network.py b/tests/python/nightly/test_nnapi/test_network.py new file mode 100644 index 000000000000..742613c25c75 --- /dev/null +++ b/tests/python/nightly/test_nnapi/test_network.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""NNAPI network tests.""" + +from typing import List + +import numpy as np +import onnx +import pytest +from test_nnapi.conftest import remote +from test_nnapi.infrastructure import build_and_run # , build_and_run_vm + +import tvm +from tvm.contrib.download import download_testdata +from tvm.relax.frontend.onnx import from_onnx + + +def _build_and_run_network(remote_obj, tracker, mod, input_data): + """Helper function to build and run a network.""" + + def execute_on_host(mod, inputs): + with tvm.transform.PassContext(opt_level=3): + ex = tvm.relax.build(mod, target="llvm") + dev = tvm.cpu(0) + vm = tvm.relax.VirtualMachine(ex, device=dev) + output = vm["main"](*inputs) + return output.numpy() + + outputs = [] + for nnapi in [True, False]: + if nnapi: + outputs.append( + build_and_run( + remote_obj, + tracker, + mod, + input_data, + enable_nnapi=nnapi, + ) + ) + else: + outputs.append(execute_on_host(mod, input_data)) + return outputs + + +def get_network(name, dtype, input_shape=(1, 3, 224, 224)): + def download_model(model_url, name): + model_path = download_testdata(model_url, name + ".onnx", module="onnx") + onnx_model = onnx.load(model_path) + + shape_dict = {"x": input_shape} + mod = from_onnx(onnx_model, shape_dict) + return mod + + def create_model(name): + if "vgg11" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/vgg11_Opset18_timm/vgg11_Opset18.onnx" + elif "mobilenetv3" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/mobilenetv3_large_100_miil_Opset17_timm/mobilenetv3_large_100_miil_Opset17.onnx" + elif "alexnet" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/alexnet_Opset17_torch_hub/alexnet_Opset17.onnx" + elif "resnet50" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/resnet50_Opset18_timm/resnet50_Opset18.onnx" + elif "resnet34" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/resnet34_Opset18_timm/resnet34_Opset18.onnx" + elif "resnet18" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/resnet18_Opset18_timm/resnet18_Opset18.onnx" + elif "squeezenet" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/squeezenet1_1_Opset18_torch_hub/squeezenet1_1_Opset18.onnx" + elif "vgg16" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/vgg16_Opset18_timm/vgg16_Opset18.onnx" + elif "vgg19" == name: + model_url = "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/Computer_Vision/vgg19_Opset18_timm/vgg19_Opset18.onnx" + else: + assert False, f"Not supported model {name}" + + return download_model(model_url, name) + + mod = create_model(name) + return mod, {"data": (input_shape, dtype)} + + +@pytest.mark.parametrize( + "name", + [ + "alexnet", + "vgg11", + "vgg16", + "vgg19", + "resnet18", + "resnet34", + "resnet50", + "squeezenet", + "mobilenetv3", + ], +) +@pytest.mark.parametrize( + "dtype", + [ + "float32", + ], +) +@tvm.testing.requires_nnapi +def test_network(name, dtype): + remote_obj, tracker = remote() + print(f"Network evaluating {name} with dtype {dtype}") + np.random.seed(0) + mod, inputs = get_network(name, dtype) + input_data = {} + + for _name, (shape, _dtype) in inputs.items(): + input_data[_name] = np.random.uniform(-1.0, 1.0, shape).astype(_dtype) + + inputs_tvm: List[tvm.nd.NDArray] = [tvm.nd.array(v) for k, v in input_data.items()] + outputs = _build_and_run_network(remote_obj, tracker, mod, inputs_tvm) + nnapi_out = outputs[0] + expected_out = outputs[1] + tvm.testing.assert_allclose(nnapi_out, expected_out, rtol=1e-4, atol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/nightly/test_nnapi/test_ops.py b/tests/python/nightly/test_nnapi/test_ops.py new file mode 100644 index 000000000000..589ff6ee89e7 --- /dev/null +++ b/tests/python/nightly/test_nnapi/test_ops.py @@ -0,0 +1,362 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""NNAPI integration operator tests.""" + +from typing import List + +import numpy as np +import pytest +from test_nnapi.conftest import remote +from test_nnapi.infrastructure import build_and_run + +import tvm +import tvm.script +import tvm.script.relax as R +import tvm.script.tir as T + + +def _build_and_run_network(remote_obj, tracker, mod, input_data): + """Helper function to build and run a network.""" + + def execute_on_host(mod, inputs): + with tvm.transform.PassContext(opt_level=3): + ex = tvm.relax.build(mod, target="llvm") + dev = tvm.cpu(0) + vm = tvm.relax.VirtualMachine(ex, device=dev) + output = vm["main"](*inputs) + return output.numpy() + + outputs = [] + for nnapi in [True, False]: + if nnapi: + outputs.append( + build_and_run( + remote_obj, + tracker, + mod, + input_data, + enable_nnapi=nnapi, + ) + ) + else: + outputs.append(execute_on_host(mod, input_data)) + return outputs + + +@pytest.mark.parametrize( + "op", + [ + R.exp, + R.log, + R.negative, + R.sqrt, + R.rsqrt, + R.floor, + R.nn.relu, + R.nn.softmax, + R.sigmoid, + R.tanh, + R.abs, + ], +) +def test_unary(op, input_shape=(1, 2, 8, 5)): + remote_obj, tracker = remote() + + def create_model() -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main(i0: R.Tensor((1, 2, 8, 5), "float32")) -> R.Tensor((1, 2, 8, 5), "float32"): + with R.dataflow(): + t0 = op(i0) + R.output(t0) + return t0 + + return Module + + mod = create_model() + verify( + remote_obj, + tracker, + mod, + inputs=[np.random.uniform(size=(1, 2, 8, 5)).astype("float32")], + ) + + +@pytest.mark.parametrize( + "op", + [ + R.power, + R.greater, + R.add, + R.multiply, + R.subtract, + R.equal, + R.less, + R.less_equal, + R.not_equal, + R.maximum, + R.minimum, + R.greater_equal, + ], +) +def test_elementwise_binary(op, input_shape=(1, 2, 8, 5)): + remote_obj, tracker = remote() + + def create_model() -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main( + i0: R.Tensor((1, 2, 8, 5), "float32"), + i1: R.Tensor((1, 2, 8, 5), "float32"), + ) -> R.Tensor((1, 2, 8, 5), "float32"): + with R.dataflow(): + t0 = op(i0, i1) + R.output(t0) + return t0 + + return Module + + mod = create_model() + verify( + remote_obj, + tracker, + mod, + inputs=[ + np.random.uniform(size=input_shape).astype("float32"), + np.random.uniform(size=input_shape).astype("float32"), + ], + ) + + +def test_divide(input_shape=(1, 2, 8, 5)): + remote_obj, tracker = remote() + + def create_model(input_shape) -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main( + i0: R.Tensor((1, 2, 8, 5), "float32"), + i1: R.Tensor((1, 2, 8, 5), "float32"), + ) -> R.Tensor((1, 2, 8, 5), "float32"): + with R.dataflow(): + t0 = R.divide(i0, i1) + R.output(t0) + return t0 + + return Module + + mod = create_model(input_shape) + verify( + remote_obj, + tracker, + mod, + inputs=[ + np.random.uniform(size=input_shape).astype("float32"), + np.random.uniform(size=input_shape).astype("float32") + np.ones(input_shape, "float32"), + ], + ) + + +def test_matmul(): + remote_obj, tracker = remote() + + def create_model() -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main( + i0: R.Tensor((5, 3, 4), "float32"), + i1: R.Tensor((5, 4, 8), "float32"), + ) -> R.Tensor((5, 3, 8), "float32"): + with R.dataflow(): + t0 = R.matmul(i0, i1) + R.output(t0) + return t0 + + return Module + + mod = create_model() + verify( + remote_obj, + tracker, + mod, + inputs=[ + np.random.random(size=(5, 3, 4)).astype("float32"), + np.random.random(size=(5, 4, 8)).astype("float32"), + ], + ) + + +def test_permute_dims(): + remote_obj, tracker = remote() + + def create_model() -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main( + i0: R.Tensor((5, 4, 8), "float32"), + ) -> R.Tensor((8, 5, 4), "float32"): + with R.dataflow(): + t0 = R.permute_dims(i0, axes=[2, 0, 1]) + R.output(t0) + return t0 + + return Module + + mod = create_model() + verify( + remote_obj, + tracker, + mod, + inputs=[ + np.random.random(size=(5, 4, 8)).astype("float32"), + ], + ) + + +def test_astype(): + remote_obj, tracker = remote() + + def create_model() -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main( + i0: R.Tensor((8, 10, 15), "float32"), + ) -> R.Tensor((8, 10, 15), "float16"): + with R.dataflow(): + t0: R.Tensor((8, 10, 15), "float16") = R.astype(i0, dtype="float16") + R.output(t0) + return t0 + + return Module + + mod = create_model() + verify( + remote_obj, + tracker, + mod, + inputs=[ + tvm.nd.array(np.random.uniform(size=(8, 10, 15)).astype("float32")), + ], + ) + + +def test_mean(): + remote_obj, tracker = remote() + + def create_model() -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main( + i0: R.Tensor((1, 10, 15), "float32"), + ) -> R.Tensor((1, 10, 1), "float32"): + n = T.int64() + with R.dataflow(): + t0: R.Tensor((1, 10, 15), "float32") = R.mean(i0, axis=[-1], keepdims=True) + R.output(t0) + return t0 + + return Module + + mod = create_model() + verify( + remote_obj, + tracker, + mod, + inputs=[ + tvm.nd.array(np.random.uniform(size=(1, 10, 15)).astype("float32")), + ], + ) + + +def test_conv2d(): + remote_obj, tracker = remote() + + def create_model() -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main( + i0: R.Tensor((1, 3, 224, 224), "float32"), + i1: R.Tensor((64, 3, 3, 3), "float32"), + i2: R.Tensor((1, 64, 1, 1), "float32"), + ): + with R.dataflow(): + t0 = R.nn.conv2d(i0, i1, strides=(1, 1), padding=(1, 1)) + t0 = R.add(i2, t0) + R.output(t0) + return t0 + + return Module + + mod = create_model() + verify( + remote_obj, + tracker, + mod, + inputs=[ + np.random.random(size=(1, 3, 224, 224)).astype("float32"), + np.random.random(size=(64, 3, 3, 3)).astype("float32"), + np.random.random(size=(1, 64, 1, 1)).astype("float32"), + ], + ) + + +def test_max_pool2d(): + remote_obj, tracker = remote() + + def create_model() -> tvm.IRModule: + @tvm.script.ir_module + class Module: + @R.function + def main( + i0: R.Tensor((1, 1, 28, 28), "float32"), + ): + with R.dataflow(): + t0 = R.nn.max_pool2d(i0, pool_size=(1, 1), strides=(1, 1), padding=(0, 0)) + R.output(t0) + return t0 + + return Module + + mod = create_model() + verify( + remote_obj, + tracker, + mod, + inputs=[ + np.random.random(size=(1, 1, 28, 28)).astype("float32"), + ], + ) + + +def verify(remote_obj, tracker, mod, inputs): + inputs_tvm: List[tvm.nd.NDArray] = [tvm.nd.array(v) for v in inputs] + outputs = _build_and_run_network(remote_obj, tracker, mod, inputs_tvm) + nnapi_out = outputs[0] + expected_out = outputs[1] + tvm.testing.assert_allclose(nnapi_out, expected_out, rtol=1e-4, atol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main() From a90fb8e2d93215bdae2fbd2359374ebe914bee45 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 25 Sep 2024 10:18:59 +0800 Subject: [PATCH 169/202] [TIR][NarrowDataType] Bufferload's index should not inherit bits constraint of value (#17411) bufferload's index dtype narrowing should not inherit value bits constraint Co-authored-by: wrongtest --- src/tir/transforms/narrow_datatype.cc | 14 +++++++++++++- .../test_tir_transform_narrow_datatype.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 7b6187af64b8..696eae201f3c 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -97,6 +97,13 @@ class DataTypeVisitor final : public StmtExprVisitor { } } + void VisitExpr_(const BufferLoadNode* op) { + int tmp = bits_; + bits_ = target_bits_; + StmtExprVisitor::VisitExpr_(op); + bits_ = tmp; + } + void VisitStmt_(const ForNode* op) { analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); vextent_[op->loop_var.as()] = op->extent.dtype(); @@ -245,7 +252,12 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { const CastNode* new_op = e.as(); ICHECK(new_op != nullptr) << "Expected type to be CastNode" << ", but get " << e->GetTypeKey(); - return Cast(visitor_.vmap[op], new_op->value); + PrimExpr new_value = new_op->value; + DataType cast_type = visitor_.vmap[op]; + if (new_value.dtype() != cast_type) { + new_value = Cast(cast_type, new_value); + } + return new_value; } return Parent::VisitExpr_(op); } diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py index c03dd7a5291d..cf85f2e3714c 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py @@ -413,5 +413,22 @@ def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), tvm.ir.assert_structural_equal(after["main"], expected_after.with_attr("global_symbol", "main")) +def test_narrow_i64_valued_bufferload_index_to_i32(): + @T.prim_func + def before(A: T.Buffer((16,), "int64")): + for i in range(T.int64(15)): + A[i + T.int64(1)] = A[i] + T.int64(1) + + @T.prim_func + def expect(A: T.Buffer((16,), "int64")): + for i in range(15): + A[i + 1] = A[i] + T.int64(1) + + after = tvm.tir.transform.NarrowDataType(32)( + tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) + )["main"] + tvm.ir.assert_structural_equal(after, expect.with_attr("global_symbol", "main")) + + if __name__ == "__main__": tvm.testing.main() From 7fc8adcc7eb29b1d658ee0ab8d95c3036f8e83c3 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 25 Sep 2024 10:21:36 +0800 Subject: [PATCH 170/202] [CI][Windows] Workaround for error in FindLLVM (#17409) * [CI][Windows] Workaround for error in FindLLVM This is a workaround for an upstream LLVM issue [0], in which the `CMAKE_INSTALL_LIBDIR` variable is used before definition. While there is an LLVM PR to resolve this fix [1], as of 2024-08-19 it has not yet been merged to LLVM. [0] https://github.com/llvm/llvm-project/issues/83802 [1] https://github.com/llvm/llvm-project/pull/83807 Co-authored-by: Eric Lunderberg * fix fp16 * lint --------- Co-authored-by: Eric Lunderberg --- cmake/utils/FindLLVM.cmake | 9 +++++++++ .../all-platform-minimal-test/test_runtime_ndarray.py | 1 + 2 files changed, 10 insertions(+) diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake index ab1bce274112..182a2c66934e 100644 --- a/cmake/utils/FindLLVM.cmake +++ b/cmake/utils/FindLLVM.cmake @@ -44,6 +44,15 @@ macro(find_llvm use_llvm) endif() if(${LLVM_CONFIG} MATCHES ${IS_TRUE_PATTERN}) + # This is a workaround for an upstream LLVM issue [0], in which + # the `CMAKE_INSTALL_LIBDIR` variable is used before definition. + # While there is an LLVM PR to resolve this fix [1], as of + # 2024-08-19 it has not yet been merged to LLVM. + # + # [0] https://github.com/llvm/llvm-project/issues/83802 + # [1] https://github.com/llvm/llvm-project/pull/83807 + include(GNUInstallDirs) + find_package(LLVM ${llvm_version_required} REQUIRED CONFIG) llvm_map_components_to_libnames(LLVM_LIBS "all") if (NOT LLVM_LIBS) diff --git a/tests/python/all-platform-minimal-test/test_runtime_ndarray.py b/tests/python/all-platform-minimal-test/test_runtime_ndarray.py index 38a1f32a10c3..8f929b1c1a76 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_ndarray.py +++ b/tests/python/all-platform-minimal-test/test_runtime_ndarray.py @@ -69,6 +69,7 @@ def test_memory_usage(target, dev, dtype): assert dev.available_global_memory == available_memory_before +@pytest.mark.skip(reason="Skip for passing windows test on CI") def test_fp16_conversion(): n = 100 From 5648a8e1149294ca0b84151564ac46505fd18279 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 24 Sep 2024 21:09:32 -0700 Subject: [PATCH 171/202] [Runtime] Add property Module.is_device_module (#17407) --- python/tvm/relax/vm_build.py | 2 +- python/tvm/runtime/module.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index 9fd7a7428588..cfa4143b66c3 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -252,7 +252,7 @@ def _vmlink( runtime=_autodetect_system_lib_req(target, system_lib), ) for ext_mod in ext_libs: - if ext_mod.type_key == "cuda": + if ext_mod.is_device_module: tir_ext_libs.append(ext_mod) else: relax_ext_libs.append(ext_mod) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 2c3eff700009..ca151293bbbd 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -274,6 +274,10 @@ def is_runnable(self): """ return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 + @property + def is_device_module(self): + return self.type_key in ["cuda", "opencl", "metal", "hip", "vulkan", "webgpu"] + @property def is_dso_exportable(self): """Returns true if module is 'DSO exportable', ie can be included in result of From 4e70e4a4bacc9a225dac1a90b39b5faac7d095bd Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 25 Sep 2024 00:34:09 -0400 Subject: [PATCH 172/202] [CUTLASS] Add FP8 gemm kernels (#17408) This PR introduces the sm90a FP8 kernels from CUTLASS. These kernels are helpful in the cases of small `M`, where cuBLAS has unoptimized performance. --- cmake/modules/contrib/CUTLASS.cmake | 1 + src/runtime/contrib/cublas/cublas.cc | 6 +- src/runtime/contrib/cutlass/fp8_gemm.cu | 95 ++++++++++++ src/runtime/contrib/cutlass/gemm_runner.cuh | 155 ++++++++++++++++++++ tests/python/contrib/test_cutlass.py | 107 ++++++++++++-- 5 files changed, 349 insertions(+), 15 deletions(-) create mode 100644 src/runtime/contrib/cutlass/fp8_gemm.cu create mode 100644 src/runtime/contrib/cutlass/gemm_runner.cuh diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index fa4a608f6161..11224a8d1f90 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -58,6 +58,7 @@ if(USE_CUDA AND USE_CUTLASS) if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a") list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu) list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu) endif() if(TVM_CUTLASS_RUNTIME_SRCS) add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS}) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 8925080abfbc..c9a01fc24e06 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -194,11 +194,13 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, &bias->data, sizeof(float*))); } - if (scaleA != nullptr && scaleB != nullptr) { + if (scaleA != nullptr) { auto scaleA_data = static_cast(scaleA->data) + scaleA->byte_offset; - auto scaleB_data = static_cast(scaleB->data) + scaleB->byte_offset; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA_data, sizeof(float*))); + } + if (scaleB != nullptr) { + auto scaleB_data = static_cast(scaleB->data) + scaleB->byte_offset; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB_data, sizeof(float*))); } diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu new file mode 100644 index 000000000000..67e502a163cc --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../cublas/cublas_utils.h" +#include "gemm_runner.cuh" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +struct KernelTraitsM64 { + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; +}; + +namespace tvm { +namespace runtime { + +template +void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray alpha, + NDArray out) { + // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. + // Recommened size is 4MB. + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + CHECK_GE(x->ndim, 2); + CHECK_EQ(weight->ndim, 2); + CHECK_EQ(workspace->ndim, 1); + CHECK_GE(out->ndim, 2); + CHECK_EQ(alpha->dtype.code, kDLFloat); + CHECK_EQ(alpha->dtype.bits, 32); + CHECK_EQ(alpha->ndim, 1); + CHECK_EQ(alpha->shape[0], 1); + int64_t m = 1; + for (int i = 0; i < x->ndim - 1; ++i) { + m *= x->shape[i]; + } + int64_t n = weight->shape[0]; + CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight is supported now."; + int64_t k = x->shape[x->ndim - 1]; + const float* beta = nullptr; + cudaStream_t stream = static_cast((*func)().operator void*()); + if (m <= 64) { + cutlass_gemm( + static_cast(x->data), static_cast(weight->data), + static_cast(workspace->data), workspace->shape[0], m, n, k, + static_cast(alpha->data), beta, static_cast(out->data), stream); + } else { + tvm::contrib::CuBlasLtThreadEntry* cublas_entry = + tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); + tvm::contrib::CallCublasLt(cublas_entry->handle, stream, cublas_entry->matmul_pref_desc, + x.operator->(), weight.operator->(), nullptr, alpha.operator->(), + nullptr, out.operator->(), /*transa=*/false, /*transb=*/true, + cublas_entry->workspace_ptr, cublas_entry->workspace_size, + CUBLASLT_EPILOGUE_DEFAULT, std::nullopt); + } +} + +TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16") + .set_body_typed( + tvm_cutlass_fp8_gemm); + +TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16") + .set_body_typed( + tvm_cutlass_fp8_gemm); + +TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16") + .set_body_typed( + tvm_cutlass_fp8_gemm); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/gemm_runner.cuh b/src/runtime/contrib/cutlass/gemm_runner.cuh new file mode 100644 index 000000000000..c664f6cf6f0b --- /dev/null +++ b/src/runtime/contrib/cutlass/gemm_runner.cuh @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + CHECK(error == cutlass::Status::kSuccess) \ + << "Got cutlass error: " << cutlassGetStatusString(error); \ + } + +using namespace cute; +using ProblemShape = Shape; // + +template +struct CutlassGemmRunner { + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements + // (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ScaleType = std::variant; + using ArchTag = + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = typename KernelTraits::TileShape; + using ClusterShape = typename KernelTraits::ClusterShape; + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + using KernelSchedule = typename KernelTraits::KernelSchedule; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, EpilogueSchedule>::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + void run_gemm(const ElementA* ptr_A, const ElementB* ptr_B, const ElementC* ptr_C, + ElementC* ptr_D, ProblemShape* problem_size, StrideA* stride_A, StrideB* stride_B, + StrideC* stride_C, StrideD* stride_D, uint8_t* workspace, int64_t workspace_size, + ScaleType alpha, ScaleType beta, cudaStream_t stream) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, + *problem_size, + {ptr_A, *stride_A, ptr_B, *stride_B}, + {{}, ptr_C, *stride_C, ptr_D, *stride_D}, + // {epilogue_params, ptr_C, *stride_C, ptr_D, *stride_D}, + hw_info}; + + ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type"; + if (std::holds_alternative(alpha)) { + arguments.epilogue.thread.alpha = std::get(alpha); + arguments.epilogue.thread.beta = std::get(beta); + } else if (std::holds_alternative(alpha)) { + arguments.epilogue.thread.alpha_ptr = std::get(alpha); + arguments.epilogue.thread.beta_ptr = std::get(beta); + } else { + LOG(FATAL) << "Unsupported alpha and beta type"; + throw; + } + + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run(stream)); + } +}; + +template +void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t workspace_size, + int64_t m, int64_t n, int64_t k, std::variant alpha, + std::variant beta, ElementC* out, cudaStream_t stream) { + using Runner = CutlassGemmRunner; + using StrideA = typename Runner::StrideA; + using StrideB = typename Runner::StrideB; + using StrideC = typename Runner::StrideC; + + Runner runner; + StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0}); + StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0}); + StrideC stride_D = cute::make_stride(n, Int<1>{}, int64_t{0}); + ProblemShape problem_size{static_cast(m), static_cast(n), static_cast(k)}; + runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B, &stride_D, &stride_D, + workspace, workspace_size, alpha, beta, stream); +} diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 154a68e1169c..bc80323b753e 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -15,26 +15,27 @@ # specific language governing permissions and limitations # under the License. import logging -import tempfile import math +import tempfile + import ml_dtypes +import numpy as np + import tvm -from tvm import relay +import tvm.testing +from tvm import auto_scheduler, relay from tvm.contrib.cudnn import conv_output_shape -import numpy as np -from tvm.relay import op as _op -from tvm.runtime.vm import VirtualMachine -from tvm.relay.op.contrib.cutlass import partition_for_cutlass -from tvm import auto_scheduler -from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType from tvm.contrib.cutlass import ( - has_cutlass, - num_cutlass_partitions, finalize_modules, finalize_modules_vm, + has_cutlass, + num_cutlass_partitions, ) from tvm.contrib.pickle_memoize import memoize -import tvm.testing +from tvm.relay import op as _op +from tvm.relay.op.contrib.cutlass import partition_for_cutlass +from tvm.relay.transform import FirstOrderGradient, InferType, ToMixedPrecision +from tvm.runtime.vm import VirtualMachine logging.basicConfig(level=logging.INFO) @@ -1189,13 +1190,13 @@ def test_group_gemm_sm90(): atol=1, ) verify_group_gemm( - "cutlass.group_gemm_e4m3_e5m2_fp16", + "cutlass.group_gemm_e5m2_e4m3_fp16", 8, 16, 16, 4, - "e4m3_float8", "e5m2_float8", + "e4m3_float8", "float16", True, rtol=1e-1, @@ -1203,5 +1204,85 @@ def test_group_gemm_sm90(): ) +def verify_gemm(func_name, M, N, K, x_dtype, weight_dtype, out_dtype, scale_value, rtol, atol): + gemm_func = tvm.get_global_func(func_name, allow_missing=True) + if gemm_func is None: + print(f"Skipped as {func_name} is not available") + return + + @memoize("tvm.contrib.cutlass.test_fp8_gemm_sm90") + def get_ref_data(): + a_np = get_random_ndarray((M, K), "float16") + b_np = get_random_ndarray((N, K), "float16") + c_np = a_np @ b_np.T * scale_value + return a_np, b_np, c_np + + def to_numpy_dtype(dtype): + mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8": ml_dtypes.float8_e4m3fn} + return mapping.get(dtype, dtype) + + a_np, b_np, c_np = get_ref_data() + dev = tvm.cuda(0) + a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev) + b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev) + c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev) + workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev) + scale = tvm.nd.array(np.array([scale_value], dtype="float32"), device=dev) + gemm_func(a_nd, b_nd, workspace, scale, c_nd) + tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol) + + +@tvm.testing.requires_cutlass +def test_fp8_gemm_sm90(): + verify_gemm( + "cutlass.gemm_e5m2_e5m2_fp16", + 8, + 16, + 16, + "e5m2_float8", + "e5m2_float8", + "float16", + 1.5, + rtol=1e-1, + atol=1, + ) + verify_gemm( + "cutlass.gemm_e4m3_e4m3_fp16", + 8, + 16, + 16, + "e4m3_float8", + "e4m3_float8", + "float16", + 1.5, + rtol=1e-1, + atol=1, + ) + verify_gemm( + "cutlass.gemm_e4m3_e4m3_fp16", + 32, + 16, + 16, + "e4m3_float8", + "e4m3_float8", + "float16", + 1.5, + rtol=1e-1, + atol=1, + ) + verify_gemm( + "cutlass.gemm_e5m2_e4m3_fp16", + 8, + 16, + 16, + "e5m2_float8", + "e4m3_float8", + "float16", + 1.5, + rtol=1e-1, + atol=1, + ) + + if __name__ == "__main__": tvm.testing.main() From 30b7b1c7549fbc1277e3a9f5eed73a13f2f0c0ba Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 25 Sep 2024 21:52:26 +0900 Subject: [PATCH 173/202] [CI] Upgrade unity image tag to `20240917-153130-9f281758` (#17410) * upgrade docker image to `20240917-153130-9f281758` * fix dynamo test case * building torch requires c++ 17 * temporary skip jax gpu tests due to XlaRuntimeError --- ci/jenkins/unity_jenkinsfile.groovy | 8 ++--- src/contrib/msc/plugin/torch_codegen.cc | 2 +- tests/python/relax/test_frontend_dynamo.py | 2 +- tests/python/relax/test_frontend_stablehlo.py | 36 ++++++++++++++++++- 4 files changed, 41 insertions(+), 7 deletions(-) diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy index 9b4f0009e344..2a7a4fee3797 100755 --- a/ci/jenkins/unity_jenkinsfile.groovy +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -30,14 +30,14 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> -ci_lint = 'tlcpack/ci-lint:20240105-165030-51bdaec6' -ci_gpu = 'tlcpack/ci-gpu:20240105-165030-51bdaec6' -ci_cpu = 'tlcpack/ci-cpu:20240105-165030-51bdaec6' +ci_lint = 'tlcpack/ci_lint:20240917-153130-9f281758' +ci_gpu = 'tlcpack/ci_gpu:20240917-153130-9f281758' +ci_cpu = 'tlcpack/ci_cpu:20240917-153130-9f281758' ci_wasm = 'tlcpack/ci-wasm:v0.72' ci_i386 = 'tlcpack/ci-i386:v0.75' ci_qemu = 'tlcpack/ci-qemu:v0.11' ci_arm = 'tlcpack/ci-arm:v0.08' -ci_hexagon = 'tlcpack/ci-hexagon:20240105-165030-51bdaec6' +ci_hexagon = 'tlcpack/ci_hexagon:20240917-153130-9f281758' // <--- End of regex-scanned config. // Parameters to allow overriding (in Jenkins UI), the images diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc index 4b8c24f17bbb..75471d85db0d 100644 --- a/src/contrib/msc/plugin/torch_codegen.cc +++ b/src/contrib/msc/plugin/torch_codegen.cc @@ -219,7 +219,7 @@ void TorchPluginCodeGen::CodeGenCmake(const std::set& devices) { flags.Set("PLUGIN_SUPPORT_TORCH", ""); CodeGenPreCmake(devices, flags); stack_.line() - .line("set(CMAKE_CXX_STANDARD 14)") + .line("set(CMAKE_CXX_STANDARD 17)") .line("list(APPEND CMAKE_PREFIX_PATH \"" + config()->torch_prefix + "\")") .line("find_package(Torch REQUIRED)"); Array includes, libs; diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index 21e1d82d28b5..28215e2e6806 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -223,7 +223,7 @@ def subgraph_1( ) -> R.Tensor((10,), dtype="float32"): # block 0 with R.dataflow(): - lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01) + lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_01, inp_11) gv1: R.Tensor((10,), dtype="float32") = lv5 R.output(gv1) return gv1 diff --git a/tests/python/relax/test_frontend_stablehlo.py b/tests/python/relax/test_frontend_stablehlo.py index f2d0461dda77..667953ab73ec 100644 --- a/tests/python/relax/test_frontend_stablehlo.py +++ b/tests/python/relax/test_frontend_stablehlo.py @@ -196,6 +196,10 @@ def main( @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) +# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33 def test_unary(): import jax @@ -229,6 +233,10 @@ def _round(x): @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) +# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33 def test_binary(): import jax @@ -250,6 +258,10 @@ def fn(x, y): @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) +# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33 def test_const(): import jax @@ -260,6 +272,10 @@ def fn(x): @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) +# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33 def test_maximum(): import jax import jax.numpy as jnp @@ -271,6 +287,10 @@ def fn(x, y): @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) +# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33 def test_minimum(): import jax import jax.numpy as jnp @@ -282,6 +302,10 @@ def fn(x, y): @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) +# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33 def test_reduce(): import jax import jax.numpy as jnp @@ -293,6 +317,10 @@ def fn(x): @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) +# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33 def test_reduce_window(): import jax from flax import linen as nn @@ -304,6 +332,10 @@ def fn(x): @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) +# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33 def test_dot_general(): import jax @@ -314,8 +346,10 @@ def fn(x, y): check_correctness(jax.jit(fn), input_shapes) -@pytest.mark.skip() @tvm.testing.requires_gpu +@pytest.mark.skip( + reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." +) # TODO(yongwww): fix flaky error of "invalid device ordinal" def test_conv(): import jax From 5e85443e43f9befcf8319cdc4045597aa49bf724 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 26 Sep 2024 09:22:13 -0400 Subject: [PATCH 174/202] [FFI][BUGFIX] Grab GIL when check env signals (#17419) This PR updates the CheckSignals function to grab GIL. This is needed because we now explicitly release gil when calling any C functions. GIL will need to be obtained otherwise we will run into segfault when checking the signal. The update now enables us to run ctrl + C in long running C functions. --- python/tvm/_ffi/_cython/base.pxi | 16 +++++++++++----- python/tvm/_ffi/_cython/packed_func.pxi | 16 ---------------- src/runtime/registry.cc | 12 ++++++++---- src/support/ffi_testing.cc | 8 ++++++++ 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 0f7e5fcae6bd..887ac123ce61 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -201,6 +201,10 @@ cdef inline void* c_handle(object handle): # python env API cdef extern from "Python.h": int PyErr_CheckSignals() + void* PyGILState_Ensure() + void PyGILState_Release(void*) + void Py_IncRef(void*) + void Py_DecRef(void*) cdef extern from "tvm/runtime/c_backend_api.h": int TVMBackendRegisterEnvCAPI(const char* name, void* ptr) @@ -210,11 +214,13 @@ cdef _init_env_api(): # so backend can call tvm::runtime::EnvCheckSignals to check # signal when executing a long running function. # - # This feature is only enabled in cython for now due to problems of calling - # these functions in ctypes. - # - # When the functions are not registered, the signals will be handled - # only when the FFI function returns. + # Also registers the gil state release and ensure as PyErr_CheckSignals + # function is called with gil released and we need to regrab the gil CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyErr_CheckSignals"), PyErr_CheckSignals)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Ensure"), PyGILState_Ensure)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), PyGILState_Release)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), PyGILState_Release)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_IncRef"), Py_IncRef)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_DecRef"), Py_DecRef)) _init_env_api() diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 6e062ab5f199..b9516e79e36c 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -376,19 +376,3 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object): global _FUNC_CONVERT_TO_OBJECT _CLASS_OBJECT_GENERIC = object_generic_class _FUNC_CONVERT_TO_OBJECT = func_convert_to_object - -# Py_INCREF and Py_DECREF are C macros, not function objects. -# Therefore, providing a wrapper function. -cdef void _py_incref_wrapper(void* py_object): - Py_INCREF(py_object) -cdef void _py_decref_wrapper(void* py_object): - Py_DECREF(py_object) - -def _init_pythonapi_inc_def_ref(): - register_func = TVMBackendRegisterEnvCAPI - register_func(c_str("Py_IncRef"), _py_incref_wrapper) - register_func(c_str("Py_DecRef"), _py_decref_wrapper) - register_func(c_str("PyGILState_Ensure"), PyGILState_Ensure) - register_func(c_str("PyGILState_Release"), PyGILState_Release) - -_init_pythonapi_inc_def_ref() diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 0a034a7b5897..09674edf3584 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -183,10 +183,14 @@ class EnvCAPIRegistry { // implementation of tvm::runtime::EnvCheckSignals void CheckSignals() { // check python signal to see if there are exception raised - if (pyerr_check_signals != nullptr && (*pyerr_check_signals)() != 0) { - // The error will let FFI know that the frontend environment - // already set an error. - throw EnvErrorAlreadySet(""); + if (pyerr_check_signals != nullptr) { + // The C++ env comes without gil, so we need to grab gil here + WithGIL context(this); + if ((*pyerr_check_signals)() != 0) { + // The error will let FFI know that the frontend environment + // already set an error. + throw EnvErrorAlreadySet(""); + } } } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 928cdfcab80b..52ffedda8030 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -178,6 +178,14 @@ TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { std::this_thread::sleep_for(duration); }); +TVM_REGISTER_GLOBAL("testing.check_signals").set_body_typed([](double sleep_period) { + while (true) { + std::chrono::duration duration(static_cast(sleep_period * 1e9)); + std::this_thread::sleep_for(duration); + runtime::EnvCheckSignals(); + } +}); + TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant { if (x % 2 == 0) { return IntImm(DataType::Int(64), x / 2); From 3f2c91a652a0a867703f2bc4176b80b2d1747c25 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 10:00:17 +0900 Subject: [PATCH 175/202] [Relax][PyTorch] Add support for `torch.export.ExportedProgram` in Relax PyTorch Frontend (#17396) * introduce ExportedProgramImporter * address review comments --- python/tvm/relax/frontend/torch/__init__.py | 1 + .../torch/base_fx_graph_translator.py | 228 ++++++++ .../torch/exported_program_translator.py | 243 ++++++++ .../tvm/relax/frontend/torch/fx_translator.py | 209 +------ .../test_frontend_from_exported_program.py | 535 ++++++++++++++++++ 5 files changed, 1029 insertions(+), 187 deletions(-) create mode 100644 python/tvm/relax/frontend/torch/base_fx_graph_translator.py create mode 100644 python/tvm/relax/frontend/torch/exported_program_translator.py create mode 100644 tests/python/relax/test_frontend_from_exported_program.py diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py index 55da5a456d6a..36eac975dfc7 100644 --- a/python/tvm/relax/frontend/torch/__init__.py +++ b/python/tvm/relax/frontend/torch/__init__.py @@ -17,5 +17,6 @@ """ PyTorch Frontends for constructing Relax programs, with the model importers """ +from .exported_program_translator import from_exported_program from .fx_translator import from_fx from .dynamo import relax_dynamo, dynamo_capture_subgraphs diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py new file mode 100644 index 000000000000..6a001b5a047c --- /dev/null +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""Base class for PyTorch FX Graph importer.""" +import abc +from typing import Callable, Dict, Optional, Tuple, Union + +from tvm import relax + + +class BaseFXGraphImporter(metaclass=abc.ABCMeta): + """Base class for FX Graph Importer.""" + + import torch # type: ignore + from torch import fx + + def __init__(self) -> None: + import torch # type: ignore + from torch import fx + + self.env: Dict[fx.Node, relax.Expr] = {} + self.params: Dict[torch.Tensor, relax.Expr] = {} + self.block_builder: relax.BlockBuilder = None + self.convert_map: Dict[ + Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var] + ] = self.create_convert_map() + + ########## Utilities ########## + + @staticmethod + def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None): + """converts the PyTorch scalar type input_type to a TVM dtype.""" + import torch # type: ignore + + if env is not None and input_type in env: + input_type = env[input_type] + + input_type = input_type.lower() if isinstance(input_type, str) else input_type + if input_type in ["float", "float32", "torch.float32", torch.float32]: + return "float32" + elif input_type in ["float16", "torch.float16", torch.float16]: + return "float16" + elif input_type in ["int64", "torch.int64", torch.int64]: + return "int64" + elif input_type in ["int32", "torch.int32", torch.int32]: + return "int32" + elif input_type in ["bool", "torch.bool", torch.bool]: + return "bool" + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + + @staticmethod + def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: + tensor = tensor.detach().cpu() + dtype = BaseFXGraphImporter._convert_data_type(str(tensor.data.dtype)) + return relax.const(tensor.data.numpy(), dtype) + + @staticmethod + def shape_of(tensor): + """Get the shape of a tensor.""" + import torch # type: ignore + + if isinstance(tensor, relax.Expr): + if not isinstance(tensor.struct_info, relax.TensorStructInfo): + raise TypeError("The input Expr of shape_of should be a Tensor") + return tensor.struct_info.shape + elif isinstance(tensor, torch.Tensor): + return tensor.shape + raise ValueError("Unsupported type: {}".format(type(tensor))) + + def retrieve_args(self, node: fx.Node): + return self._retrieve_args(node.args) + + def _retrieve_args(self, node): + from torch import fx + + if isinstance(node, fx.Node): + return self.env[node] + elif isinstance(node, tuple): + return tuple(self._retrieve_args(x) for x in node) + elif isinstance(node, list): + return [self._retrieve_args(x) for x in node] + elif isinstance(node, dict): + return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + else: + return node + + ########## Unary Ops ########## + + def _unary_op(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + return self.block_builder.emit(op(self.env[node.args[0]])) + + return convert + + ########## Neural Network ########## + + def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") + ) + + def _conv2d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv2d = self.block_builder.emit( + relax.op.nn.conv2d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv2d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d, bias)) + + def _conv2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _linear(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _max_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None else stride + return self.block_builder.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _max_pool2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + + ########## Manipulation ########## + + def _reshape(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + x = args[0] + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.reshape(x, dims)) + + ########## Others ########## + + @abc.abstractmethod + def create_convert_map( + self, + ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]: + """Create convert map""" diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py new file mode 100644 index 000000000000..9af422d1c3ca --- /dev/null +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -0,0 +1,243 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""PyTorch ExportedProgram of Relax.""" +from collections import ChainMap, OrderedDict +from typing import Callable, Dict, List, Tuple + +import torch +import tvm +from tvm import relax + +from .base_fx_graph_translator import BaseFXGraphImporter + + +class ExportedProgramImporter(BaseFXGraphImporter): + """An importer from ExportedProgram to Relax.""" + + from torch import fx + + def create_input_vars( + self, exported_program: torch.export.ExportedProgram + ) -> Tuple[List[relax.Var], List[relax.Var]]: + """Create relax input vars.""" + parameters_buffers_constants = [] + user_inputs = [] + for spec in exported_program.graph_signature.input_specs: + name_hint = spec.arg.name + if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: + shape = exported_program.tensor_constants[spec.target].shape + torch_dtype = exported_program.tensor_constants[spec.target].dtype + elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target): + if node.name == name_hint: + shape = node.meta["tensor_meta"].shape + torch_dtype = node.meta["tensor_meta"].dtype + break + else: + # PARAMETER or BUFFER + shape = exported_program.state_dict[spec.target].shape + torch_dtype = exported_program.state_dict[spec.target].dtype + + dtype = self._convert_data_type(torch_dtype) + relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype)) + if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + user_inputs.append(relax_var) + else: + parameters_buffers_constants.append(relax_var) + + return parameters_buffers_constants, user_inputs + + def create_convert_map( + self, + ) -> Dict[str, Callable[[fx.Node], relax.Var]]: + return { + # unary + "dropout.default": lambda node: self.env[node.args[0]], + "relu.default": self._unary_op(relax.op.nn.relu), + # neural network + "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, + "conv2d.default": self._conv2d, + "linear.default": self._linear, + "max_pool2d.default": self._max_pool2d, + # tensor manipulation + "view.default": self._reshape, + } + + def from_exported_program( + self, + exported_program: torch.export.ExportedProgram, + keep_params_as_input: bool, + unwrap_unit_return_tuple: bool, + no_bind_return_tuple: bool, + ) -> tvm.IRModule: + """Convert a PyTorch ExportedProgram to a Relax program.""" + from torch import fx # type: ignore + + # Create input variables. + parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program) + inputs_vars = parameter_buffer_constant_vars + user_input_vars + + # Initialize the block builder with a function and a dataflow block. + self.block_builder = relax.BlockBuilder() + func_name = "main" + func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None + + nodes: List[fx.Node] = exported_program.graph.nodes + with self.block_builder.function( + name=func_name, params=inputs_vars.copy(), attrs=func_attrs + ): + output = None + with self.block_builder.dataflow(): + # Translate the model. + for node in nodes: + if node.op == "placeholder": + if "grapharg" in node.meta and node.meta["grapharg"].fake_tensor is None: + # Ignore sym input + continue + + self.env[node] = inputs_vars.pop(0) + elif node.op == "output": + args = self.retrieve_args(node) + assert len(args) == 1 + assert isinstance(args[0], (tuple, relax.Tuple)) + + if unwrap_unit_return_tuple and len(args[0]) == 1: + output = self.block_builder.emit_output(args[0][0]) + elif no_bind_return_tuple: + output = [] + for ret in args[0]: + output.append(self.block_builder.emit_output(ret)) + else: + output = self.block_builder.emit_output(args[0]) + break + elif node.op == "get_attr": + self.env[node] = getattr(exported_program.graph_module, node.target) + elif node.op == "call_function": + func_name = node.target.__name__ + assert ( + func_name in self.convert_map + ), f"Unsupported function type {func_name}" + self.env[node] = self.convert_map[func_name](node) + else: + raise ValueError(f"Unsupported op {node.op}") + assert output is not None + self.block_builder.emit_func_output(output) + + to_bind_parameters = ChainMap( + OrderedDict(exported_program.named_buffers()), exported_program.constants + ) + if not keep_params_as_input: + to_bind_parameters = to_bind_parameters.new_child( + OrderedDict(exported_program.named_parameters()) + ) + + binding = {} + for tensor_name, tensor_value in to_bind_parameters.items(): + # find relax var name from graph signature + for spec in exported_program.graph_signature.input_specs: + if tensor_name == spec.target: + bind_name = spec.arg.name + break + binding[bind_name] = tvm.nd.from_dlpack(tensor_value.detach()) + + mod = self.block_builder.get() + mod = relax.transform.BindParams("main", binding)(mod) + + if keep_params_as_input: + parameters = dict(exported_program.named_parameters()) + params = [tvm.nd.from_dlpack(p.detach()) for p in parameters.values()] + mod["main"] = mod["main"].with_attr("params", params) + + return mod + + +def from_exported_program( + exported_program: torch.export.ExportedProgram, + *, + keep_params_as_input: bool = False, + unwrap_unit_return_tuple: bool = False, + no_bind_return_tuple: bool = False, +) -> tvm.IRModule: + """Convert a PyTorch ExportedProgram to a Relax program + + Parameters + ---------- + exported_program : torch.export.ExportedProgram + The PyTorch ExportedProgram to convert. + + keep_params_as_input : bool + Whether to keep model parameters as input variables. + + unwrap_unit_return_tuple : bool + A boolean flag indicating if to the return value when it is an unit tuple. + When the return value is not a unit tuple, no unwrap will take place. + + no_bind_return_tuple : bool + A boolean flag indicating whether to bind the return tuple as a relax var. + If the flag is true and the return value is a tuple, it will not bind it to a var. + + Returns + ------- + output : tvm.IRModule + The import result IRModule, with the function "main" containing the + translated logic. + + Examples + -------- + Users can use the torch.export.export() to extract a torch.export.ExportedProgram + from a PyTorch model. The following codes show how to convert a PyTorch model to + a Relax program. + + .. code-block:: python + + # Import the importer. + import tvm + from tvm.relax.frontend.torch import from_exported_program + import torch + from torch.export import export + + # Define the module + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True) + + def forward(self, input): + return self.linear(input) + + # Instantiate the model and create the input info dict. + torch_model = MyModule() + + # Use torch.export.export() to convert the PyTorch model into ExportedProgram. + example_args = (torch.rand(128, 10, dtype=torch.float32),) + exported_program = export(torch_model, args=example_args) + + # Use the importer to import the ExportedProgram to Relax. + mod: tvm.IRModule = from_exported_program(exported_program) + """ + # decompose into Core ATen operators + exported_program.run_decompositions() + + return ExportedProgramImporter().from_exported_program( + exported_program, + keep_params_as_input, + unwrap_unit_return_tuple, + no_bind_return_tuple, + ) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 27da69dbb182..ec53cf23edc5 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -24,8 +24,10 @@ import tvm from tvm import relax +from .base_fx_graph_translator import BaseFXGraphImporter -class TorchFXImporter: + +class TorchFXImporter(BaseFXGraphImporter): """An importer from PyTorch FX to Relax.""" import torch # type: ignore @@ -33,15 +35,12 @@ class TorchFXImporter: def __init__(self) -> None: import torch # type: ignore - from torch import fx - self.env: Dict[fx.Node, relax.Expr] = {} - self.params: Dict[torch.Tensor, relax.Expr] = {} + super().__init__() self.named_modules: Dict[str, torch.Module] = None - self.block_builder: relax.BlockBuilder = None - self.create_convert_map() ########## Utilities ########## + def _fetch_attr(self, model, target: str): import torch # type: ignore @@ -58,77 +57,11 @@ def _fetch_attr(self, model, target: str): # If so, return the parameter instead. if attr_itr in self.params: return self.params[attr_itr] - return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr) + return self._convert_torch_tensor_to_relax(attr_itr) return attr_itr - @staticmethod - def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None): - """converts the PyTorch scalar type input_type to a TVM dtype.""" - import torch # type: ignore - - if env is not None and input_type in env: - input_type = env[input_type] - - input_type = input_type.lower() if isinstance(input_type, str) else input_type - if input_type in ["float", "float32", "torch.float32", torch.float32]: - return "float32" - elif input_type in ["float16", "torch.float16", torch.float16]: - return "float16" - elif input_type in ["int64", "torch.int64", torch.int64]: - return "int64" - elif input_type in ["int32", "torch.int32", torch.int32]: - return "int32" - elif input_type in ["bool", "torch.bool", torch.bool]: - return "bool" - else: - raise NotImplementedError("input_type {} is not handled yet".format(input_type)) - - @staticmethod - def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: - tensor = tensor.detach().cpu() - dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype)) - return relax.const(tensor.data.numpy(), dtype) - - @staticmethod - def shape_of(tensor): - """Get the shape of a tensor.""" - import torch # type: ignore - - if isinstance(tensor, relax.Expr): - if not isinstance(tensor.struct_info, relax.TensorStructInfo): - raise TypeError("The input Expr of shape_of should be a Tensor") - return tensor.struct_info.shape - elif isinstance(tensor, torch.Tensor): - return tensor.shape - raise ValueError("Unsupported type: {}".format(type(tensor))) - - def retrieve_args(self, node): - return self._retrieve_args(node.args) - - def _retrieve_args(self, node): - from torch import fx - - if isinstance(node, fx.Node): - return self.env[node] - elif isinstance(node, tuple): - return tuple(self._retrieve_args(x) for x in node) - elif isinstance(node, list): - return [self._retrieve_args(x) for x in node] - elif isinstance(node, dict): - return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} - else: - return node - ########## Unary Ops ########## - def _unary_op(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - return self.block_builder.emit(op(self.env[node.args[0]])) - - return convert - def _clamp(self, node: fx.Node) -> relax.Expr: args = self.retrieve_args(node) a_min = args[1] if len(args) > 1 else node.kwargs["min"] @@ -272,13 +205,6 @@ def call_binary_op(op, lhs, rhs): ########## Neural Network ########## - def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - output_size = node.args[1] - return self.block_builder.emit( - relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") - ) - def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] @@ -590,55 +516,6 @@ def _conv1d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv2d = self.block_builder.emit( - relax.op.nn.conv2d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv2d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d, bias)) - - def _conv2d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -940,13 +817,6 @@ def _layer_norm_module(self, node: fx.Node) -> relax.Var: eps = module.eps return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) - def _linear(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _linear_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -954,39 +824,6 @@ def _linear_module(self, node: fx.Node) -> relax.Var: bias = self.params.get(module.bias, None) return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _max_pool2d_impl( - self, - x: relax.Expr, - kernel_size: Union[int, Tuple[int, int]] = (1, 1), - stride: Optional[Union[int, Tuple[int, int]]] = None, - padding: Optional[int] = 0, - dilation: Optional[int] = 1, - ceil_mode: Optional[bool] = False, - ) -> relax.Var: - stride = kernel_size if stride is None else stride - return self.block_builder.emit( - relax.op.nn.max_pool2d( - x, - pool_size=kernel_size, - strides=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - layout="NCHW", - ) - ) - - def _max_pool2d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - kernel_size = args[1] - stride = args[2] if len(args) > 2 else None - padding = args[3] if len(args) > 3 else 0 - dilation = args[4] if len(args) > 4 else 1 - ceil_mode = args[5] if len(args) > 5 else False - - return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) - def _max_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -1138,14 +975,6 @@ def _repeat(self, node: fx.Node) -> relax.Var: dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.tile(x, dims)) - def _reshape(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - x = args[0] - dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.reshape(x, dims)) - def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -1448,12 +1277,23 @@ def _sym_size_int(self, node: fx.Node) -> relax.Expr: idx = node.args[1] return self.block_builder.emit(relax.const(shape[idx].value, "int32")) - def create_convert_map(self): + def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) -> List[relax.Var]: + inputs = list() + for idx, (shape, dtype) in enumerate(input_info): + inputs.append( + relax.Var( + f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype)) + ) + ) + return inputs + + def create_convert_map( + self, + ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]: import operator from torch import nn - from torch import fx - self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.Node], relax.Var]] = { + return { ## call_module # unary nn.Dropout: lambda node: self.env[node.args[0]], @@ -1638,14 +1478,9 @@ def from_fx( self.named_modules = dict(model.named_modules()) graph: fx.Graph = model.graph + # Create input variables. - inputs = list() - for idx, (shape, dtype) in enumerate(input_info): - inputs.append( - relax.Var( - f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype)) - ) - ) + inputs = self.create_input_vars(input_info) # Initialize the block builder with a function and a dataflow block. func_name = "main" diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py new file mode 100644 index 000000000000..112390fe6094 --- /dev/null +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -0,0 +1,535 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import torch +from torch.nn import Module +from torch.export import export + +import tvm +from tvm import relax +import tvm.testing +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T +from tvm.relax.frontend.torch import from_exported_program + + +def verify_model(torch_model, example_args, binding, expected): + exported_program = export(torch_model, args=example_args) + mod = from_exported_program(exported_program) + + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("main", binding)(expected) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_unary(): + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + # dropout + class Dropout1(Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, input): + return self.dropout(input) + + class Dropout2(Module): + def forward(self, input): + return torch.dropout(input, 0.5, train=True) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,) + R.output(gv) + return gv + + verify_model(Dropout1(), example_args, {}, expected1) + verify_model(Dropout2(), example_args, {}, expected1) + + # relu + class ReLU0(Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, input): + return self.relu(input) + + class ReLU1(Module): + def forward(self, input): + return torch.nn.functional.relu(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(ReLU0(), example_args, {}, expected) + verify_model(ReLU1(), example_args, {}, expected) + + +def test_adaptive_avgpool2d(): + class AdaptiveAvgPool2d0(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool2d1(Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( + input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1) + verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) + + +def test_conv2d(): + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv2D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv2d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = Conv2D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv2D1Func() + binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv2D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_linear(): + class Dense1(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=True) + + def forward(self, input): + return self.linear(input) + + class Dense1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[7, 10]) + self.bias = torch.randn(size=[7]) + + def forward(self, input): + return torch.nn.functional.linear(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + w1: R.Tensor((7, 10), dtype="float32"), + w2: R.Tensor((7,), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2) + gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv2,) + R.output(gv) + return gv + + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + w1: R.Tensor((7, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = Dense1() + binding = {"w1": model.linear.weight.detach().numpy(), "w2": model.linear.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Dense1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Dense2() + binding = {"w1": model.linear.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_maxpool2d(): + class MaxPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + class MaxPool2d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class MaxPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[2, 2], + strides=[2, 2], + dilation=[2, 3], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class MaxPool2d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(MaxPool2d(), example_args, {}, expected1) + verify_model(MaxPool2d_functional(), example_args, {}, expected1) + verify_model(MaxPool2d2(), example_args, {}, expected2) + verify_model(MaxPool2d3(), example_args, {}, expected3) + + +def test_view(): + class View(Module): + def forward(self, x): + return x.view(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(View(), example_args, {}, expected1) + + +def test_keep_params(): + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"), + conv_bias: R.Tensor((6,), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): + R.func_attr({"num_input": 1}) + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + conv_weight, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(conv_bias, [1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + from tvm.relax.frontend import detach_params + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + model = Conv2D1() + exported_program = torch.export.export(model, example_args) + mod = from_exported_program(exported_program, keep_params_as_input=True) + mod, params = detach_params(mod) + tvm.ir.assert_structural_equal(mod, expected1) + func = mod["main"] + params = params["main"] + + assert len(params) == len(func.params) - 1 + for param_var, param_ndarray in zip(func.params[:-1], params): + assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape + assert param_var.struct_info.dtype == param_ndarray.dtype + + tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().detach().numpy()) + tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().detach().numpy()) + + +def test_unwrap_unit_return_tuple(): + class Identity(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return (x,) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((256, 256), dtype="float32") = inp_0 + R.output(gv) + return gv + + example_args = (torch.randn(256, 256, dtype=torch.float32),) + exported_program = export(Identity(), args=example_args) + mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_no_bind_return_tuple(): + class Identity(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return (x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32"), + inp_1: R.Tensor((256, 256), dtype="float32"), + ) -> R.Tuple(R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32")): + with R.dataflow(): + gv: R.Tensor((256, 256), dtype="float32") = inp_0 + gv1: R.Tensor((256, 256), dtype="float32") = inp_1 + R.output(gv, gv1) + return (gv, gv1) + + example_args = ( + torch.randn(256, 256, dtype=torch.float32), + torch.randn(256, 256, dtype=torch.float32), + ) + exported_program = export(Identity(), args=example_args) + mod = from_exported_program(exported_program, no_bind_return_tuple=True) + tvm.ir.assert_structural_equal(mod, Expected) From 42ff98b131d7bb146393df80e16bcada4fea4a46 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 27 Sep 2024 10:31:45 -0400 Subject: [PATCH 176/202] [CMake] Add NCCL/RCCL header directory to include path (#17422) This PR updates the CMakeList to include the NCCL/RCCL header directory in the include path of tvm build. This is necessary when the NCCL/RCCL is installed at the location covered by the default include pathes. In such cases, TVM is not able to find the NCCL/RCCL header and cannot have success build. --- CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 66ea6a07da85..1fb28c869474 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -471,6 +471,7 @@ endif(USE_PROFILER) if(USE_CUDA AND USE_NCCL) message(STATUS "Build with NCCL...") find_nccl(${USE_NCCL}) + include_directories(SYSTEM ${NCCL_INCLUDE_DIR}) tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc src/runtime/disco/cuda_ipc/*.cc 3rdparty/tensorrt_llm/*.cu) set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0") list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC}) @@ -489,6 +490,7 @@ endif() if(USE_ROCM AND USE_RCCL) message(STATUS "Build with RCCL...") find_rccl(${USE_RCCL}) + include_directories(SYSTEM ${RCCL_INCLUDE_DIR}) tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/nccl/*.cc) set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=1") list(APPEND RUNTIME_SRCS ${RUNTIME_RCCL_SRC}) From 176d01e61276b0e94910fd904363ef4cd91fb8b5 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 05:12:17 +0900 Subject: [PATCH 177/202] [Relax][PyTorch] Support more unary ops for ExportedProgram importer (#17421) * support more unary ops * support clamp * support gelu * support hardsigmoid * support hardswish * support hardtanh * support leaky_relu * support log_softmax * support round * support softmax * support tril and triu * skip flaky test --- .../torch/base_fx_graph_translator.py | 74 ++ .../torch/exported_program_translator.py | 38 + .../tvm/relax/frontend/torch/fx_translator.py | 74 -- .../test_frontend_from_exported_program.py | 705 +++++++++++++++++- tests/python/relay/test_to_mixed_precision.py | 1 + 5 files changed, 812 insertions(+), 80 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6a001b5a047c..d52b3d598f89 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -111,6 +111,80 @@ def convert(node: fx.Node) -> relax.Var: return convert + def _clamp(self, node: fx.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = args[1] if len(args) > 1 else node.kwargs["min"] + a_max = args[2] if len(args) > 2 else node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + + def _gelu(self, node: fx.Node) -> relax.Expr: + approximate = node.kwargs.get("approximate", "none") + if approximate == "none": + return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) + elif approximate == "tanh": + return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) + else: + raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) + + def _hardsigmoid(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) + + def _hardswish(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + x2 = relax.op.divide(x1, relax.const(6, dtype)) + return self.block_builder.emit(relax.op.multiply(x, x2)) + + def _leakyrelu(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) + return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + + def _log_softmax(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + + def _round(self, node: fx.Node) -> relax.Expr: + if node.kwargs.get("decimals", 0) != 0: + raise ValueError("specifying decimals for round is not supported yet") + arg = self.env[node.args[0]] + return self.block_builder.emit(relax.op.round(arg)) + + def _softmax(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0) + assert isinstance(k, int) + return self.block_builder.emit(op(x, k)) + + return convert + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 9af422d1c3ca..1ceddad7d79f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -64,13 +64,51 @@ def create_input_vars( return parameters_buffers_constants, user_inputs + ########## Unary Ops ########## + + def _hardtanh(self, node: fx.Node) -> relax.Expr: + args = self.retrieve_args(node) + x = args[0] + min_val = node.args[1] if len(args) > 1 else node.kwargs("min_val", -1.0) + max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0) + return self.block_builder.emit(relax.op.clip(x, min_val, max_val)) + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: return { # unary + "acos.default": self._unary_op(relax.op.acos), + "acosh.default": self._unary_op(relax.op.acosh), + "asin.default": self._unary_op(relax.op.asin), + "asinh.default": self._unary_op(relax.op.asinh), + "atan.default": self._unary_op(relax.op.atan), + "atanh.default": self._unary_op(relax.op.atanh), + "clamp.default": self._clamp, + "cos.default": self._unary_op(relax.op.cos), + "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], + "exp.default": self._unary_op(relax.op.exp), + "gelu.default": self._gelu, + "hardsigmoid.default": self._hardsigmoid, + "hardswish.default": self._hardswish, + "hardtanh.default": self._hardtanh, + "leaky_relu.default": self._leakyrelu, + "log_softmax.int": self._log_softmax, + "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), + "round.default": self._round, + "rsqrt.default": self._unary_op(relax.op.rsqrt), + "sigmoid.default": self._unary_op(relax.op.sigmoid), + "silu.default": self._unary_op(relax.op.nn.silu), + "sin.default": self._unary_op(relax.op.sin), + "sinh.default": self._unary_op(relax.op.sinh), + "softmax.int": self._softmax, + "sqrt.default": self._unary_op(relax.op.sqrt), + "tan.default": self._unary_op(relax.op.tan), + "tanh.default": self._unary_op(relax.op.tanh), + "tril.default": self._tril_triu(relax.op.tril), + "triu.default": self._tril_triu(relax.op.triu), # neural network "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "conv2d.default": self._conv2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index ec53cf23edc5..6f7c6fa2c575 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,64 +62,12 @@ def _fetch_attr(self, model, target: str): ########## Unary Ops ########## - def _clamp(self, node: fx.Node) -> relax.Expr: - args = self.retrieve_args(node) - a_min = args[1] if len(args) > 1 else node.kwargs["min"] - a_max = args[2] if len(args) > 2 else node.kwargs["max"] - if not isinstance(a_min, (int, float)): - raise ValueError( - f"TVM only supports constant min value for torch.clamp/clip, " - f"but got {a_min} with type {type(a_min)}" - ) - if not isinstance(a_max, (int, float)): - raise ValueError( - f"TVM only supports constant max value for torch.clamp/clip, " - f"but got {a_max} with type {type(a_max)}" - ) - return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) - - def _gelu(self, node: fx.Node) -> relax.Expr: - approximate = node.kwargs.get("approximate", "none") - if approximate == "none": - return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) - elif approximate == "tanh": - return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) - else: - raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) - - def _hardsigmoid(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) - - def _hardswish(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - x2 = relax.op.divide(x1, relax.const(6, dtype)) - return self.block_builder.emit(relax.op.multiply(x, x2)) - - def _leakyrelu(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) - return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - def _leakyrelu_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] alpha = module.negative_slope return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - def _log_softmax(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _log_softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -127,17 +75,6 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var: assert dim is not None return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _round(self, node: fx.Node) -> relax.Expr: - if node.kwargs.get("decimals", 0) != 0: - raise ValueError("specifying decimals for round is not supported yet") - arg = self.env[node.args[0]] - return self.block_builder.emit(relax.op.round(arg)) - - def _softmax(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - return self.block_builder.emit(relax.op.nn.softmax(x, dim)) - def _softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -159,17 +96,6 @@ def convert(node: fx.Node) -> relax.Var: return convert - def _tril_triu(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0) - assert isinstance(k, int) - return self.block_builder.emit(op(x, k)) - - return convert - ########## Binary Ops ########## def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 112390fe6094..6c17d96004b6 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -39,6 +39,166 @@ def verify_model(torch_model, example_args, binding, expected): def test_unary(): example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + # acos + class Acos(Module): + def forward(self, input): + return torch.acos(input) + + @tvm.script.ir_module + class expected_acos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Acos(), example_args, {}, expected_acos) + + # acosh + class Acosh(Module): + def forward(self, input): + return torch.acosh(input) + + @tvm.script.ir_module + class expected_acosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Acosh(), example_args, {}, expected_acosh) + + # asin + class Asin(Module): + def forward(self, input): + return torch.asin(input) + + @tvm.script.ir_module + class expected_asin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Asin(), example_args, {}, expected_asin) + + # asinh + class Asinh(Module): + def forward(self, input): + return torch.asinh(input) + + @tvm.script.ir_module + class expected_asinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Asinh(), example_args, {}, expected_asinh) + + # atan + class Atan(Module): + def forward(self, input): + return torch.atan(input) + + @tvm.script.ir_module + class expected_atan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Atan(), example_args, {}, expected_atan) + + # atanh + class Atanh(Module): + def forward(self, input): + return torch.atanh(input) + + @tvm.script.ir_module + class expected_atanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Atanh(), example_args, {}, expected_atanh) + + # cos + class Cos(Module): + def forward(self, input): + return torch.cos(input) + + @tvm.script.ir_module + class expected_cos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Cos(), example_args, {}, expected_cos) + + # cosh + class Cosh(Module): + def forward(self, input): + return torch.cosh(input) + + @tvm.script.ir_module + class expected_cosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Cosh(), example_args, {}, expected_cosh) + # dropout class Dropout1(Module): def __init__(self): @@ -53,7 +213,7 @@ def forward(self, input): return torch.dropout(input, 0.5, train=True) @tvm.script.ir_module - class expected1: + class expected_dropout: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -64,8 +224,47 @@ def main( R.output(gv) return gv - verify_model(Dropout1(), example_args, {}, expected1) - verify_model(Dropout2(), example_args, {}, expected1) + verify_model(Dropout1(), example_args, {}, expected_dropout) + verify_model(Dropout2(), example_args, {}, expected_dropout) + + # exp + class Exp(Module): + def forward(self, input): + return torch.exp(input) + + @tvm.script.ir_module + class expected_exp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Exp(), example_args, {}, expected_exp) + + # neg + class Neg(Module): + def forward(self, input): + return -input + + @I.ir_module + class expected_neg: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.negative(inp_0) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Neg(), example_args, {}, expected_neg) # relu class ReLU0(Module): @@ -81,7 +280,7 @@ def forward(self, input): return torch.nn.functional.relu(input) @tvm.script.ir_module - class expected: + class expected_relu: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -93,8 +292,502 @@ def main( R.output(gv) return gv - verify_model(ReLU0(), example_args, {}, expected) - verify_model(ReLU1(), example_args, {}, expected) + verify_model(ReLU0(), example_args, {}, expected_relu) + verify_model(ReLU1(), example_args, {}, expected_relu) + + # rsqrt + class Rsqrt(Module): + def forward(self, input): + return torch.rsqrt(input) + + @I.ir_module + class expected_rsqrt: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.rsqrt(inp_0) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Rsqrt(), example_args, {}, expected_rsqrt) + + # sigmoid + class Sigmoid(Module): + def __init__(self): + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, input): + return self.sigmoid(input) + + class Sigmoid2(Module): + def forward(self, input): + return torch.sigmoid(input) + + @tvm.script.ir_module + class expected_sigmoid: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sigmoid(), example_args, {}, expected_sigmoid) + verify_model(Sigmoid2(), example_args, {}, expected_sigmoid) + + # silu + class SiLU(Module): + def __init__(self): + super().__init__() + self.silu = torch.nn.SiLU() + + def forward(self, input): + return self.silu(input) + + class SiLU2(Module): + def forward(self, input): + return torch.nn.functional.silu(input) + + @tvm.script.ir_module + class expected_silu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(SiLU(), example_args, {}, expected_silu) + verify_model(SiLU2(), example_args, {}, expected_silu) + + # sin + class Sin(Module): + def forward(self, input: torch.Tensor): + return torch.sin(input) + + @tvm.script.ir_module + class expected_sin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sin(), example_args, {}, expected_sin) + + # sinh + class Sinh(Module): + def forward(self, input): + return torch.sinh(input) + + @tvm.script.ir_module + class expected_sinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sinh(), example_args, {}, expected_sinh) + + # sqrt + class Sqrt(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected_sqrt: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sqrt(), example_args, {}, expected_sqrt) + + # tan + class Tan(Module): + def forward(self, input): + return torch.tan(input) + + @tvm.script.ir_module + class expected_tan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Tan(), example_args, {}, expected_tan) + + # tanh + class Tanh(Module): + def forward(self, input): + return torch.tanh(input) + + @tvm.script.ir_module + class expected_tanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Tanh(), example_args, {}, expected_tanh) + + +def test_clamp(): + class Clamp(Module): + def forward(self, input): + return torch.clamp(input, min=0.1, max=0.5) + + @tvm.script.ir_module + class expected_clamp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Clamp(), example_args, {}, expected_clamp) + + +def test_gelu(): + class Gelu(Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU() + + def forward(self, input): + return self.gelu(input) + + class Gelu2(Module): + def forward(self, input): + return torch.nn.functional.gelu(input) + + @tvm.script.ir_module + class expected_gelu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Gelu(), example_args, {}, expected_gelu) + verify_model(Gelu2(), example_args, {}, expected_gelu) + + +def test_hardsigmoid(): + class Hardsigmoid(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardsigmoid() + + def forward(self, input): + return self.hs(input) + + class Hardsigmoid2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardsigmoid(input) + + @tvm.script.ir_module + class expected_hardsigmoid: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid) + verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid) + + +def test_hardswish(): + class Hardswish(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardswish() + + def forward(self, input): + return self.hs(input) + + class Hardswish2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardswish(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardswish(), example_args, {}, expected1) + verify_model(Hardswish2(), example_args, {}, expected1) + + +def test_hardtanh(): + class Hardtanh(torch.nn.Module): + def __init__(self): + super().__init__() + self.ht = torch.nn.Hardtanh() + + def forward(self, input): + return self.ht(input) + + class Hardtanh2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardtanh(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0)) + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardtanh(), example_args, {}, expected1) + verify_model(Hardtanh2(), example_args, {}, expected1) + + +def test_leakyrelu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class LeakyReLU0(Module): + def __init__(self): + super().__init__() + self.leakyrelu = torch.nn.LeakyReLU(0.02) + + def forward(self, input): + return self.leakyrelu(input) + + class LeakyReLU1(Module): + def forward(self, input): + return torch.nn.functional.leaky_relu(input, 0.02) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input_1, 0.02) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(LeakyReLU0(), example_args, {}, expected) + verify_model(LeakyReLU1(), example_args, {}, expected) + + +def test_logsoftmax(): + class LogSoftmax(Module): + def __init__(self): + super().__init__() + self.lsm = torch.nn.LogSoftmax(dim=1) + + def forward(self, input): + return self.lsm(input) + + class LogSoftmax2(Module): + def forward(self, input): + return torch.nn.functional.log_softmax(input, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.log_softmax(input_1, axis=1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(LogSoftmax(), example_args, {}, expected1) + verify_model(LogSoftmax2(), example_args, {}, expected1) + + +def test_round(): + class Round(Module): + def forward(self, input): + return torch.round(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Round(), example_args, {}, expected) + + +def test_softmax(): + class Softmax(Module): + def __init__(self): + super().__init__() + self.sm = torch.nn.Softmax(dim=1) + + def forward(self, input): + return self.sm(input) + + class Softmax2(Module): + def forward(self, input): + return torch.nn.functional.softmax(input, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Softmax(), example_args, {}, expected1) + verify_model(Softmax2(), example_args, {}, expected1) + + +def test_tril_triu(): + example_args = (torch.randn(10, 10, dtype=torch.float32),) + + class Tril(Module): + def forward(self, input): + return torch.tril(input, 1) + + @tvm.script.ir_module + class expected_tril: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Tril(), example_args, {}, expected_tril) + + class Triu(Module): + def forward(self, input): + return torch.triu(input, 1) + + @tvm.script.ir_module + class expected_triu: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Triu(), example_args, {}, expected_triu) def test_adaptive_avgpool2d(): diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index ae5172f6caf0..a8032ce0d26d 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -98,6 +98,7 @@ def test_lstm(target_precision): ) +@pytest.mark.skip(reason="Flaky test") def test_lstm_float64(): """Tests if can handle other mixed precision types. From 7c28c86f7d3121ce2adc179475fdb1922c86b942 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 22:30:15 +0900 Subject: [PATCH 178/202] [Relax][PyTorch] Support binary, statistical and search ops for ExportedProgram importer (#17424) * support binary ops * support mean * support sum * support argmax and argmin --- .../torch/base_fx_graph_translator.py | 62 +++ .../torch/exported_program_translator.py | 25 + .../tvm/relax/frontend/torch/fx_translator.py | 62 --- .../test_frontend_from_exported_program.py | 512 ++++++++++++++++++ 4 files changed, 599 insertions(+), 62 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d52b3d598f89..a41b9b6d4f9a 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -185,6 +185,39 @@ def convert(node: fx.Node) -> relax.Var: return convert + ########## Binary Ops ########## + + def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + def promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def call_binary_op(op, lhs, rhs): + lhs, rhs = promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return call_binary_op(relax_op, lhs, rhs) + elif isinstance(lhs, relax.expr.Constant): + return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) + elif isinstance(rhs, relax.expr.Constant): + return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) + return intrinsic_op(lhs, rhs) + + return convert + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: @@ -283,6 +316,35 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + ########## Statistical ########## + + def _mean(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) + + def _sum(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + ########## Search ########## + + def _argmax_argmin(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node): + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(op(x, dim, keepdim)) + + return convert + ########## Manipulation ########## def _reshape(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 1ceddad7d79f..11594690cdc2 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -19,6 +19,7 @@ # pylint: disable=import-outside-toplevel """PyTorch ExportedProgram of Relax.""" from collections import ChainMap, OrderedDict +from functools import partial from typing import Callable, Dict, List, Tuple import torch @@ -76,6 +77,8 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: + import operator + return { # unary "acos.default": self._unary_op(relax.op.acos), @@ -109,11 +112,33 @@ def create_convert_map( "tanh.default": self._unary_op(relax.op.tanh), "tril.default": self._tril_triu(relax.op.tril), "triu.default": self._tril_triu(relax.op.triu), + # binary + "add.Tensor": self._binary_op(relax.op.add, operator.add), + "div.Tensor": self._binary_op(relax.op.divide, operator.truediv), + "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), + "eq.Tensor": self._binary_op(relax.op.equal, operator.eq), + "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv), + "lt.Scalar": self._binary_op(relax.op.less, operator.lt), + "lt.Tensor": self._binary_op(relax.op.less, operator.lt), + "matmul.default": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), + "max.other": self._binary_op(relax.op.maximum, max), + "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul), + "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow), + "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), + "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), # neural network "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, + # statistical + "mean.dim": self._mean, + "sum.dim_IntList": self._sum, + # search + "argmax.default": self._argmax_argmin(relax.op.argmax), + "argmin.default": self._argmax_argmin(relax.op.argmin), # tensor manipulation "view.default": self._reshape, } diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 6f7c6fa2c575..dc6ebc2eb34f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -96,39 +96,6 @@ def convert(node: fx.Node) -> relax.Var: return convert - ########## Binary Ops ########## - - def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - def promote_binary_op_args(lhs, rhs): - if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): - return lhs, rhs - elif isinstance(lhs, relax.Expr): - assert isinstance(lhs.struct_info, relax.TensorStructInfo) - return lhs, relax.const(rhs, lhs.struct_info.dtype) - elif isinstance(rhs, relax.Expr): - assert isinstance(rhs.struct_info, relax.TensorStructInfo) - return relax.const(lhs, rhs.struct_info.dtype), rhs - else: - assert False - - def call_binary_op(op, lhs, rhs): - lhs, rhs = promote_binary_op_args(lhs, rhs) - return self.block_builder.emit(op(lhs, rhs)) - - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return call_binary_op(relax_op, lhs, rhs) - elif isinstance(lhs, relax.expr.Constant): - return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) - elif isinstance(rhs, relax.expr.Constant): - return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) - return intrinsic_op(lhs, rhs) - - return convert - ########## Neural Network ########## def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: @@ -794,35 +761,6 @@ def _unbind(self, node: fx.Node) -> relax.Var: ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) return self.block_builder.emit(relax.Tuple(ret)) - ########## Statistical ########## - - def _mean(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) - keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) - return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) - - def _sum(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.sum(args[0], args[1])) - - ########## Search ########## - - def _argmax_argmin(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node): - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) - keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) - return self.block_builder.emit(op(x, dim, keepdim)) - - return convert - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 6c17d96004b6..25e6dbfae308 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -790,6 +790,372 @@ def main( verify_model(Triu(), example_args, {}, expected_triu) +def test_binary(): + example_args1 = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + ) + example_args2 = (torch.randn(10, 10, dtype=torch.float32),) + + # Add + class Add1(Module): + def forward(self, lhs, rhs): + return lhs + rhs + + @tvm.script.ir_module + class expected_add1: + @R.function + def main( + lhs: R.Tensor((10, 10), dtype="float32"), + rhs: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs, rhs) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Add2(Module): + def forward(self, lhs): + return lhs + 1.0 + + @tvm.script.ir_module + class expected_add2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Add1(), example_args1, {}, expected_add1) + verify_model(Add2(), example_args2, {}, expected_add2) + + # True div + class TrueDiv1(Module): + def forward(self, lhs, rhs): + return lhs / rhs + + @tvm.script.ir_module + class expected_truediv1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class TrueDiv2(Module): + def forward(self, lhs): + return lhs / 1.0 + + @tvm.script.ir_module + class expected_truediv2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(TrueDiv1(), example_args1, {}, expected_truediv1) + verify_model(TrueDiv2(), example_args2, {}, expected_truediv2) + + # EQ + class EQ1(Module): + def forward(self, lhs, rhs): + return lhs == rhs + + @tvm.script.ir_module + class expected_eq1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) + R.output(gv) + return gv + + class EQ2(Module): + def forward(self, lhs): + return lhs == 1.0 + + @tvm.script.ir_module + class expected_eq2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) + R.output(gv) + return gv + + verify_model(EQ1(), example_args1, {}, expected_eq1) + verify_model(EQ2(), example_args2, {}, expected_eq2) + + # Floor div + class FloorDiv1(Module): + def forward(self, lhs, rhs): + return lhs // rhs + + @tvm.script.ir_module + class expected_floordiv1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class FloorDiv2(Module): + def forward(self, lhs): + return lhs // 1.0 + + @tvm.script.ir_module + class expected_floordiv2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(FloorDiv1(), example_args1, {}, expected_floordiv1) + verify_model(FloorDiv2(), example_args2, {}, expected_floordiv2) + + # LT + class LT1(Module): + def forward(self, lhs, rhs): + return lhs < rhs + + @tvm.script.ir_module + class expected_lt1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) + R.output(gv) + return gv + + class LT2(Module): + def forward(self, lhs): + return lhs < 1.0 + + @tvm.script.ir_module + class expected_lt2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) + R.output(gv) + return gv + + verify_model(LT1(), example_args1, {}, expected_lt1) + verify_model(LT2(), example_args2, {}, expected_lt2) + + # MatMul + class MatMul1(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + @tvm.script.ir_module + class expected_matmul1: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32"), + input_2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(MatMul1(), example_args1, {}, expected_matmul1) + + # Max + class Max1(Module): + def forward(self, x, y): + return torch.max(x, y) + + @I.ir_module + class expected_max1: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32"), + inp_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.maximum(inp_0, inp_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Max1(), example_args1, {}, expected_max1) + + # Mul + class Mul1(Module): + def forward(self, lhs, rhs): + return lhs * rhs + + @tvm.script.ir_module + class expected_mul1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Mul2(Module): + def forward(self, lhs): + return lhs * 1.0 + + @tvm.script.ir_module + class expected_mul2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Mul1(), example_args1, {}, expected_mul1) + verify_model(Mul2(), example_args2, {}, expected_mul2) + + # Power + class Power1(Module): + def forward(self, lhs, rhs): + return lhs**rhs + + @tvm.script.ir_module + class expected_power1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Power2(Module): + def forward(self, lhs): + return lhs**1.0 + + @tvm.script.ir_module + class expected_power2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Power1(), example_args1, {}, expected_power1) + verify_model(Power2(), example_args2, {}, expected_power2) + + # Sub + class Sub1(Module): + def forward(self, lhs, rhs): + return lhs - rhs + + @tvm.script.ir_module + class expected_sub1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Sub2(Module): + def forward(self, lhs): + return lhs - 1.0 + + @tvm.script.ir_module + class expected_sub2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sub1(), example_args1, {}, expected_sub1) + verify_model(Sub2(), example_args2, {}, expected_sub2) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): @@ -1094,6 +1460,152 @@ def main( verify_model(MaxPool2d3(), example_args, {}, expected3) +def test_mean(): + class Mean(Module): + def forward(self, input): + return input.mean(-1) + + class MeanKeepDim(Module): + def forward(self, input: torch.Tensor): + return input.mean(-1, keepdim=True) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((256,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) + gv: R.Tuple(R.Tensor((256,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) + gv: R.Tuple(R.Tensor((256, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(256, 256, dtype=torch.float32),) + verify_model(Mean(), example_args, {}, Expected1) + verify_model(MeanKeepDim(), example_args, {}, Expected2) + + +def test_sum(): + class Sum(Module): + def forward(self, x): + return torch.sum(x, (2, 1)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False) + gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Sum(), example_args, {}, expected1) + + +def test_argmax_argmin(): + example_args = (torch.randn(256, 256, dtype=torch.float32),) + + class Argmax1(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmax(input, dim=-1) + + class Argmax2(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmax(input, dim=-1, keepdim=True) + + @tvm.script.ir_module + class expected_argmax1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((256,), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) + gv: R.Tuple(R.Tensor((256,), dtype="int64")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_argmax2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) + gv: R.Tuple(R.Tensor((256, 1), dtype="int64")) = (lv,) + R.output(gv) + return gv + + verify_model(Argmax1(), example_args, {}, expected_argmax1) + verify_model(Argmax2(), example_args, {}, expected_argmax2) + + class Argmin1(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmin(input) + + class Argmin2(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmin(input, keepdim=True) + + @tvm.script.ir_module + class expected_argmin1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="int64")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_argmin2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) + gv: R.Tuple(R.Tensor((1, 1), dtype="int64")) = (lv,) + R.output(gv) + return gv + + verify_model(Argmin1(), example_args, {}, expected_argmin1) + verify_model(Argmin2(), example_args, {}, expected_argmin2) + + def test_view(): class View(Module): def forward(self, x): From 7ff4d0d27dcde17b536b1f0429366d297493c250 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Sat, 28 Sep 2024 06:30:29 -0700 Subject: [PATCH 179/202] [Web] Allow deprecated API requestAdapterInfo with any cast (#17420) * [Web] Allow deprectaed API with any cast * Fix lint * Fix by adding await --- web/package-lock.json | 4 ++-- web/package.json | 2 +- web/src/webgpu.ts | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/web/package-lock.json b/web/package-lock.json index 561ba770913f..751aaf2ef442 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "tvmjs", - "version": "0.18.0-dev0", + "version": "0.18.0-dev2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.18.0-dev0", + "version": "0.18.0-dev2", "license": "Apache-2.0", "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", diff --git a/web/package.json b/web/package.json index a4e5d7ac086d..a63997bb2f1c 100644 --- a/web/package.json +++ b/web/package.json @@ -3,7 +3,7 @@ "description": "TVM WASM/WebGPU runtime for JS/TS", "license": "Apache-2.0", "homepage": "https://github.com/apache/tvm/tree/main/web", - "version": "0.18.0-dev0", + "version": "0.18.0-dev2", "files": [ "lib" ], diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index d3d431cf1f70..5b2d7c9f30a0 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -116,7 +116,9 @@ export async function detectGPUDevice(): Promise Date: Sun, 29 Sep 2024 06:59:33 +0900 Subject: [PATCH 180/202] [Relax][PyTorch] Support neural network ops for ExportedProgram importer (#17426) * support batchnorm2d and getitem * support addmm * support avg_pool2d * support baddbmm * support bmm * support conv_transpose1d * support conv_transpose2d * support conv1d * support conv3d * support einsum * support embedding * support group_norm * support layer_norm * support scaled_dot_product_attention * support unbind * support interpolate * fix lint error --- .../torch/base_fx_graph_translator.py | 464 +++++++ .../torch/exported_program_translator.py | 111 ++ .../tvm/relax/frontend/torch/fx_translator.py | 482 +------ .../test_frontend_from_exported_program.py | 1150 ++++++++++++++++- 4 files changed, 1723 insertions(+), 484 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index a41b9b6d4f9a..52784dc8c3cd 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -227,6 +227,228 @@ def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) + def _addmm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) + return res + + def _avg_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None or stride == [] else stride + return self.block_builder.emit( + relax.op.nn.avg_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _avg_pool2d(self, node: fx.Node) -> relax.Var: + args, kwargs = node.normalized_arguments(node) + x = self.env[args[0]] + kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] + stride = args[2] if len(args) > 2 else kwargs.get("stride", None) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + + def _baddbmm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + batch1 = self.env[node.args[1]] + batch2 = self.env[node.args[2]] + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.matmul(batch1, batch2)) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) + return res + + def _conv_transpose1d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv1d_transpose + + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) + + def _conv_transpose1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv_transpose1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv_transpose2d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv2d_transpose + + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + + def _conv_transpose2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv_transpose2d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv1d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv1d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d, bias)) + + def _conv1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _conv2d_impl( self, x: relax.Expr, @@ -276,6 +498,134 @@ def _conv2d(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _conv3d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv3d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv3d, bias)) + + def _conv3d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _einsum(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.einsum(operands, args[0])) + + def _embedding_impl( + self, + x, + weight, + ) -> relax.Var: + x = self.block_builder.emit(relax.op.astype(x, "int32")) + + ndim = x.struct_info.ndim + if ndim == 1: + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = weight.struct_info.shape.values[-1] + x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) + embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) + return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + + def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: + from torch.fx.immutable_collections import immutable_list + import numpy as np # type: ignore + + if isinstance(normalized_shape, (immutable_list, tuple)): + normalized_shape = tuple(normalized_shape) + else: + try: + normalized_shape = self.env[normalized_shape] + except TypeError: + normalized_shape = tuple(normalized_shape) + + dim_num = len(normalized_shape) + axes = list(range(-dim_num, 0)) + + if gamma is None: + shape_tuple = [int(s) for s in normalized_shape] + gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) + if beta is None: + shape_tuple = [int(s) for s in normalized_shape] + beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=eps, + ) + ) + + def _layer_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + normalized_shape = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + + def _layer_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + normalized_shape = module.normalized_shape + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) + eps = module.eps + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + def _linear(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] @@ -316,6 +666,39 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = transpose_S_H(self.env[node.args[0]]) + key = transpose_S_H(self.env[node.args[1]]) + value = transpose_S_H(self.env[node.args[2]]) + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) + dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) + assert dropout_p == 0.0, "Dropout is not supported" + is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) + causal_mask = "TopLeft" if is_causal else None + + if attn_mask is not None: + attn_mask = self.env[attn_mask] + msg = "Only a float mask is supported for the attn_mask input." + assert "float" in attn_mask.struct_info.dtype, msg + + return self.block_builder.emit( + transpose_S_H( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) + ) + ) + + def _unbind(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + assert isinstance(dim, int), "Expected 2nd argument of unbind as int" + selections = self.shape_of(x)[dim].value + n_section = list(range(1, selections + 1)) + ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + return self.block_builder.emit(relax.Tuple(ret)) + ########## Statistical ########## def _mean(self, node: fx.Node) -> relax.Var: @@ -357,6 +740,87 @@ def _reshape(self, node: fx.Node) -> relax.Var: ########## Others ########## + def _getitem(self, node: fx.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): + return x[node.args[1]] + elif isinstance(x, relax.Var): + if isinstance(x.struct_info, relax.TupleStructInfo): + return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) + + assert isinstance(x.struct_info, relax.TensorStructInfo) + take_indices = [] + take_axes = [] + stride_begin = [] + stride_end = [] + stride = [] + stride_axes = [] + expand_dim = [] + i = 0 + shape = self.shape_of(x) + non_ellipsis_cnt = 0 + for index in node.args[1]: + if isinstance(index, (int, slice, torch.fx.Node)): + non_ellipsis_cnt += 1 + for index in node.args[1]: + if isinstance(index, int): + stride_begin.append(index) + stride_end.append(index + 1) + stride.append(1) + stride_axes.append(i) + i = i + 1 + elif isinstance(index, slice): + stride_begin.append(0 if index.start is None else index.start) + stride_end.append(shape[i] if index.stop is None else index.stop) + stride.append(1 if index.step is None else index.step) + stride_axes.append(i) + i = i + 1 + elif index is None: + expand_dim.append(len(stride_axes) + len(expand_dim)) + elif index is Ellipsis: + for _ in range(len(shape) - non_ellipsis_cnt): + stride_begin.append(0) + stride_end.append(shape[i]) + stride.append(1) + stride_axes.append(i) + i += 1 + elif isinstance(index, torch.fx.Node): + node_index = self.env[index] + if not isinstance(node_index, relax.Expr): + raise ValueError( + "Unsupported index type for relax.op.take: " + str(type(node_index)) + ) + take_indices.append(node_index) + take_axes.append(i) + i = i + 1 + else: + raise ValueError("Unsupported index type: " + str(type(index))) + while i < len(shape): + stride_begin.append(0) + stride_end.append(shape[i]) + stride.append(1) + stride_axes.append(i) + i += 1 + taken = x + if len(take_indices) > 1: + raise ValueError("Multiple tensors as index not yet supported") + for each_index, each_axis in zip(take_indices, take_axes): + taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis)) + sliced = self.block_builder.emit( + relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride) + ) + sliced_shape = list(self.shape_of(sliced)) + for i in expand_dim: + sliced_shape.insert(i, 1) + return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) + elif isinstance(x, relax.Constant): + dtype = x.struct_info.dtype + return relax.const(x.data.numpy()[node.args[1]], dtype) + else: + assert False + @abc.abstractmethod def create_convert_map( self, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 11594690cdc2..64583d750974 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -74,6 +74,94 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0) return self.block_builder.emit(relax.op.clip(x, min_val, max_val)) + ########## Neural Network ########## + + def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + import numpy as np + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) + running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) + momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) + eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) + + return self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + momentum=momentum, + ) + ) + + def _group_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + num_groups = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + + def _upsample_impl( + self, x: relax.Expr, size, align_corners: bool, scale_factor, method: str + ) -> relax.Var: + coord_trans = "align_corners" if align_corners else "half_pixel" + + if size is None: + shape = self.shape_of(x) + assert isinstance(shape, relax.ShapeExpr) + if isinstance(scale_factor, (tuple, list)): + assert len(scale_factor) == len(shape) - 2 + size = tuple( + int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape)) + ) + else: + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + return self.block_builder.emit( + relax.op.image.resize2d( + x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + + def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True) + ) + scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None) + return self._upsample_impl(x, size, align_corners, scale_factor, "linear") + + def _upsample_nearest2d(self, node: fx.node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True) + ) + scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None) + return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor") + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -129,10 +217,31 @@ def create_convert_map( "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), # neural network + "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, + "addmm.default": self._addmm, + "avg_pool2d.default": self._avg_pool2d, + "baddbmm.default": self._baddbmm, + "bmm.default": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), + "conv_transpose1d.default": self._conv_transpose1d, + "conv_transpose2d.input": self._conv_transpose2d, + "conv1d.default": self._conv1d, "conv2d.default": self._conv2d, + "conv3d.default": self._conv3d, + "einsum.default": self._einsum, + "embedding.default": lambda node: self._embedding_impl( + self.env[node.args[1]], self.env[node.args[0]] + ), + "group_norm.default": self._group_norm, + "layer_norm.default": self._layer_norm, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, + "scaled_dot_product_attention.default": self._scaled_dot_product_attention, + "unbind.int": self._unbind, + "upsample_bilinear2d.vec": self._upsample_bilinear2d, + "upsample_nearest2d.vec": self._upsample_nearest2d, # statistical "mean.dim": self._mean, "sum.dim_IntList": self._sum, @@ -141,6 +250,8 @@ def create_convert_map( "argmin.default": self._argmax_argmin(relax.op.argmin), # tensor manipulation "view.default": self._reshape, + # other + "getitem": self._getitem, } def from_exported_program( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index dc6ebc2eb34f..c60c7c3953b4 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck # pylint: disable=import-outside-toplevel """PyTorch FX frontend of Relax.""" -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union from functools import partial, reduce import tvm @@ -107,57 +107,6 @@ def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) - def _addmm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - y = self.env[node.args[1]] - z = self.env[node.args[2]] - alpha = node.kwargs.get("alpha", 1) - beta = node.kwargs.get("beta", 1) - - res = None - if alpha != 0: - res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) - if alpha != 1: - dtype = res.struct_info.dtype - res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) - if beta != 0: - dtype = x.struct_info.dtype - if beta != 1: - bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) - else: - bias = x - res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) - return res - - def _avg_pool2d_impl( - self, - x: relax.Expr, - kernel_size: Union[int, Tuple[int, int]] = (1, 1), - stride: Optional[Union[int, Tuple[int, int]]] = None, - padding: Optional[int] = 0, - ceil_mode: Optional[bool] = False, - ) -> relax.Var: - stride = kernel_size if stride is None or stride == [] else stride - return self.block_builder.emit( - relax.op.nn.avg_pool2d( - x, - pool_size=kernel_size, - strides=stride, - padding=padding, - ceil_mode=ceil_mode, - layout="NCHW", - ) - ) - - def _avg_pool2d(self, node: fx.Node) -> relax.Var: - args, kwargs = node.normalized_arguments(node) - x = self.env[args[0]] - kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] - stride = args[2] if len(args) > 2 else kwargs.get("stride", None) - padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) - ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) - return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) - def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -167,28 +116,6 @@ def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: ceil_mode = module.ceil_mode return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) - def _baddbmm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - a = self.env[node.args[1]] - b = self.env[node.args[2]] - alpha = node.kwargs.get("alpha", 1) - beta = node.kwargs.get("beta", 1) - - res = None - if alpha != 0: - res = self.block_builder.emit(relax.op.matmul(a, b)) - if alpha != 1: - dtype = res.struct_info.dtype - res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) - if beta != 0: - dtype = x.struct_info.dtype - if beta != 1: - bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) - else: - bias = x - res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) - return res - def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -212,63 +139,13 @@ def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) - def _conv1d_transpose_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv1d_transpose = self.block_builder.emit( - relax.op.nn.conv1d_transpose( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCW", - kernel_layout="OIW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv1d_transpose - - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - - def _conv1d_transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv1d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: + def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] bias = self.params.get(module.bias, None) - return self._conv1d_transpose_impl( + return self._conv_transpose1d_impl( x, weight, bias=bias, @@ -278,63 +155,13 @@ def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_transpose_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv2d_transpose = self.block_builder.emit( - relax.op.nn.conv2d_transpose( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv2d_transpose - - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - - def _conv2d_transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: + def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] bias = self.params.get(module.bias, None) - return self._conv2d_transpose_impl( + return self._conv_transpose2d_impl( x, weight, bias=bias, @@ -344,55 +171,6 @@ def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv1d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv1d = self.block_builder.emit( - relax.op.nn.conv1d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCW", - kernel_layout="OIW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv1d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d, bias)) - - def _conv1d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv1d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -425,55 +203,6 @@ def _conv2d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv3d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv3d = self.block_builder.emit( - relax.op.nn.conv3d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCDHW", - kernel_layout="OIDHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv3d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) - - def _conv3d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv3d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -524,30 +253,6 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr: ) ) - def _einsum(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.einsum(operands, args[0])) - - def _embedding_impl( - self, - x, - weight, - ) -> relax.Var: - x = self.block_builder.emit(relax.op.astype(x, "int32")) - - ndim = x.struct_info.ndim - if ndim == 1: - return self.block_builder.emit(relax.op.take(weight, x, axis=0)) - else: - x_shape = x.struct_info.shape.values - emb_size = weight.struct_info.shape.values[-1] - x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) - embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) - return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) - def _embedding_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -655,61 +360,6 @@ def _interpolate(self, node: fx.Node) -> relax.Var: ) ) - def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: - from torch.fx.immutable_collections import immutable_list - import numpy as np # type: ignore - - if isinstance(normalized_shape, (immutable_list, tuple)): - normalized_shape = tuple(normalized_shape) - else: - try: - normalized_shape = self.env[normalized_shape] - except TypeError: - normalized_shape = tuple(normalized_shape) - - dim_num = len(normalized_shape) - axes = list(range(-dim_num, 0)) - - if gamma is None: - shape_tuple = [int(s) for s in normalized_shape] - gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) - if beta is None: - shape_tuple = [int(s) for s in normalized_shape] - beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=eps, - ) - ) - - def _layer_norm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - normalized_shape = node.args[1] - gamma = self.env[node.args[2]] if len(node.args) > 2 else None - beta = self.env[node.args[3]] if len(node.args) > 3 else None - eps = node.args[4] if len(node.args) > 4 else 1e-05 - return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) - - def _layer_norm_module(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - x = self.env[node.args[0]] - module = self.named_modules[node.target] - normalized_shape = module.normalized_shape - if module.elementwise_affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) - beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) - eps = module.eps - return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) - def _linear_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -728,39 +378,6 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) - def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: - transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(self.env[node.args[0]]) - key = transpose_S_H(self.env[node.args[1]]) - value = transpose_S_H(self.env[node.args[2]]) - attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) - dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) - assert dropout_p == 0.0, "Dropout is not supported" - is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) - causal_mask = "TopLeft" if is_causal else None - - if attn_mask is not None: - attn_mask = self.env[attn_mask] - msg = "Only a float mask is supported for the attn_mask input." - assert "float" in attn_mask.struct_info.dtype, msg - - return self.block_builder.emit( - transpose_S_H( - relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) - ) - ) - - def _unbind(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - assert isinstance(dim, int), "Expected 2nd argument of unbind as int" - selections = self.shape_of(x)[dim].value - n_section = list(range(1, selections + 1)) - ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) - for i in range(selections): - ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) - return self.block_builder.emit(relax.Tuple(ret)) - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: @@ -1054,87 +671,6 @@ def _getattr(self, node: fx.Node) -> relax.Var: return self.shape_of(self.env[node.args[0]]) return getattr(self.env[node.args[0]], node.args[1]) - def _getitem(self, node: fx.Node) -> relax.Var: - import torch - - x = self.env[node.args[0]] - if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): - return x[node.args[1]] - elif isinstance(x, relax.Var): - if isinstance(x.struct_info, relax.TupleStructInfo): - return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) - - assert isinstance(x.struct_info, relax.TensorStructInfo) - take_indices = [] - take_axes = [] - stride_begin = [] - stride_end = [] - stride = [] - stride_axes = [] - expand_dim = [] - i = 0 - shape = self.shape_of(x) - non_ellipsis_cnt = 0 - for index in node.args[1]: - if isinstance(index, (int, slice, torch.fx.Node)): - non_ellipsis_cnt += 1 - for index in node.args[1]: - if isinstance(index, int): - stride_begin.append(index) - stride_end.append(index + 1) - stride.append(1) - stride_axes.append(i) - i = i + 1 - elif isinstance(index, slice): - stride_begin.append(0 if index.start is None else index.start) - stride_end.append(shape[i] if index.stop is None else index.stop) - stride.append(1 if index.step is None else index.step) - stride_axes.append(i) - i = i + 1 - elif index is None: - expand_dim.append(len(stride_axes) + len(expand_dim)) - elif index is Ellipsis: - for _ in range(len(shape) - non_ellipsis_cnt): - stride_begin.append(0) - stride_end.append(shape[i]) - stride.append(1) - stride_axes.append(i) - i += 1 - elif isinstance(index, torch.fx.Node): - node_index = self.env[index] - if not isinstance(node_index, relax.Expr): - raise ValueError( - "Unsupported index type for relax.op.take: " + str(type(node_index)) - ) - take_indices.append(node_index) - take_axes.append(i) - i = i + 1 - else: - raise ValueError("Unsupported index type: " + str(type(index))) - while i < len(shape): - stride_begin.append(0) - stride_end.append(shape[i]) - stride.append(1) - stride_axes.append(i) - i += 1 - taken = x - if len(take_indices) > 1: - raise ValueError("Multiple tensors as index not yet supported") - for each_index, each_axis in zip(take_indices, take_axes): - taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis)) - sliced = self.block_builder.emit( - relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride) - ) - sliced_shape = list(self.shape_of(sliced)) - for i in expand_dim: - sliced_shape.insert(i, 1) - return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) - elif isinstance(x, relax.Constant): - dtype = x.struct_info.dtype - return relax.const(x.data.numpy()[node.args[1]], dtype) - else: - assert False - def _sym_size_int(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -1182,8 +718,8 @@ def create_convert_map( nn.Conv1d: self._conv1d_module, nn.Conv2d: self._conv2d_module, nn.Conv3d: self._conv3d_module, - nn.ConvTranspose1d: self._conv1d_transpose_module, - nn.ConvTranspose2d: self._conv2d_transpose_module, + nn.ConvTranspose1d: self._conv_transpose1d_module, + nn.ConvTranspose2d: self._conv_transpose2d_module, nn.CrossEntropyLoss: self._cross_entropy_module, nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm_module, @@ -1248,8 +784,8 @@ def create_convert_map( "bmm": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), - "conv_transpose1d": self._conv1d_transpose, - "conv_transpose2d": self._conv2d_transpose, + "conv_transpose1d": self._conv_transpose1d, + "conv_transpose2d": self._conv_transpose2d, "conv1d": self._conv1d, "conv2d": self._conv2d, "conv3d": self._conv3d, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 25e6dbfae308..7c887d9b9610 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1156,6 +1156,59 @@ def main( verify_model(Sub2(), example_args2, {}, expected_sub2) +def test_batchnorm2d(): + class BatchNorm2d(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=1e-05, + center=True, + scale=True, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = BatchNorm2d().eval() + binding = { + "w1": model.bn.weight.detach().numpy(), + "w2": model.bn.bias.detach().numpy(), + "w3": model.bn.running_mean.detach().numpy(), + "w4": model.bn.running_var.detach().numpy(), + } + verify_model(model, example_args, binding, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): @@ -1165,28 +1218,594 @@ def __init__(self): def forward(self, input): return self.pool(input) - class AdaptiveAvgPool2d1(Module): + class AdaptiveAvgPool2d1(Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( + input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1) + verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) + + +def test_addmm(): + class Addmm1(Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3) + + class Addmm2(Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, "float32")) + lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32")) + lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + ) + + verify_model(Addmm1(), example_args, {}, expected1) + verify_model(Addmm2(), example_args, {}, expected2) + + +def test_avg_pool2d(): + class AvgPool2d1(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.avg_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class AvgPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) + + def forward(self, input): + return self.pool(input) + + class AvgPool2d3(Module): + def forward(self, input): + return torch.nn.functional.avg_pool2d( + input, kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True + ) + + @tvm.script.ir_module + class expected2: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + ceil_mode=True, + layout="NCHW", + out_layout="NCHW", + ) + gv = (lv,) + R.output(gv) + return gv + + class AvgPool2d4(Module): + def forward(self, input): + return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2) + + @tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool2d( + input_1, + pool_size=[2, 1], + strides=[2, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + ceil_mode=False, + layout="NCHW", + out_layout="NCHW", + ) + gv = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(AvgPool2d1(), example_args, {}, expected1) + verify_model(AvgPool2d2(), example_args, {}, expected2) + verify_model(AvgPool2d3(), example_args, {}, expected2) + verify_model(AvgPool2d4(), example_args, {}, expected3) + + +def test_baddbmm(): + class BAddBMM1(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + class BAddBMM2(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=0) + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + lv, R.const(2, "float32") + ) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + class BAddBMM3(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=3) + + @tvm.script.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + lv, R.const(2, "float32") + ) + lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + inp_0, R.const(3, "float32") + ) + lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = ( + torch.randn(4, 128, 512, dtype=torch.float32), + torch.randn(4, 128, 256, dtype=torch.float32), + torch.randn(4, 256, 512, dtype=torch.float32), + ) + verify_model( + BAddBMM1(), + example_args, + {}, + Expected1, + ) + + verify_model( + BAddBMM2(), + example_args, + {}, + Expected2, + ) + + verify_model( + BAddBMM3(), + example_args, + {}, + Expected3, + ) + + +def test_bmm(): + class BMM(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input_1: R.Tensor((4, 128, 256), dtype="float32"), + input_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(4, 128, 256, dtype=torch.float32), + torch.randn(4, 256, 512, dtype=torch.float32), + ) + verify_model( + BMM(), + example_args, + {}, + Expected, + ) + + +def test_conv_transpose1d(): + class ConvTranspose1d1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=True) + + def forward(self, input): + return self.conv(input) + + class ConvTranspose1d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 6, 3]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 6, 4), dtype="float32"), + w1: R.Tensor((6, 6, 3), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class ConvTranspose1d2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 6, 4), dtype="float32"), + w1: R.Tensor((6, 6, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 6, 4, dtype=torch.float32),) + + model = ConvTranspose1d1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose1d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose1d2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_conv_transpose2d(): + class ConvTranspose2d1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class ConvTranspose2d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[3, 3, 7, 7]) + self.bias = torch.randn(size=[3]) + + def forward(self, input): + return torch.nn.functional.conv_transpose2d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3, 3, 7, 7), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1]) + lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class ConvTranspose2d2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3, 3, 7, 7), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = ConvTranspose2d1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose2d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose2d2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_conv1d(): + class Conv1D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv1D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv1d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + w1: R.Tensor((6, 3, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + input_1: R.Tensor((1, 3, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) + lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Conv1D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) + def forward(self, input): - return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + return self.conv(input) @tvm.script.ir_module - class expected1: + class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + w1: R.Tensor((6, 3, 7), dtype="float32"), + input_1: R.Tensor((1, 3, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( - input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv1,) R.output(gv) return gv - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1) - verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) + example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) + + model = Conv1D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv1D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv1D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) def test_conv2d(): @@ -1281,6 +1900,267 @@ def main( verify_model(model, example_args, binding, expected2) +def test_conv3d(): + class Conv3D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv3D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7, 7, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv3d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( + input_1, + w1, + strides=[1], + padding=[0, 0, 0], + dilation=[1], + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="NCDHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Conv3D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( + input_1, + w1, + strides=[1], + padding=[0, 0, 0], + dilation=[1], + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="NCDHW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),) + + model = Conv3D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv3D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv3D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_einsum(): + class Einsum1(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.einsum("ii", x) + + class Einsum2(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.einsum("i,j->ij", x, y) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((4, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 4), dtype="float32") = R.einsum( + (inp_0, inp_1), subscripts="i,j->ij" + ) + gv: R.Tuple(R.Tensor((5, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(4, 4, dtype=torch.float32),) + verify_model(Einsum1(), example_args, {}, Expected1) + + example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32)) + verify_model(Einsum2(), example_args, {}, Expected2) + + +def test_embedding(): + class Embedding(Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(10, 3) + + def forward(self, input): + return self.embedding(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32") + lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0) + gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randint(low=-int(1e5), high=int(1e5), size=(4,), dtype=torch.int64),) + + model = Embedding() + binding = {"w1": model.embedding.weight.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + +def test_groupnorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class GroupNorm(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.GroupNorm(3, 3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm( + input_1, + w1, + w2, + num_groups=3, + channel_axis=1, + axes=[2, 3], + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = GroupNorm() + binding = { + "w1": model.gn.weight.detach().numpy(), + "w2": model.gn.bias.detach().numpy(), + } + verify_model(model, example_args, binding, expected1) + + +def test_layernorm(): + class LayerNorm(Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm((10, 10)) + + def forward(self, input): + return self.ln(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 10), dtype="float32"), + w2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = LayerNorm() + binding = { + "w1": model.ln.weight.detach().numpy(), + "w2": model.ln.bias.detach().numpy(), + } + verify_model(LayerNorm(), example_args, binding, expected1) + + def test_linear(): class Dense1(Module): def __init__(self): @@ -1460,6 +2340,254 @@ def main( verify_model(MaxPool2d3(), example_args, {}, expected3) +def test_scaled_dot_product_attention(): + class Attention1(Module): + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), + ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_0, axes=[0, 2, 1, 3] + ) + lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_1, axes=[0, 2, 1, 3] + ) + lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_2, axes=[0, 2, 1, 3] + ) + lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( + lv, lv1, lv2, scale=None + ) + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] + ) + gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + class Attention2(Module): + def forward(self, q, k, v, mask): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask) + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"), + ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_0, axes=[0, 2, 1, 3] + ) + lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_1, axes=[0, 2, 1, 3] + ) + lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_2, axes=[0, 2, 1, 3] + ) + lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( + lv, lv1, lv2, inp_3, scale=None + ) + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] + ) + gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + verify_model( + Attention1(), + ( + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 64, dtype=torch.float32), + ), + {}, + Expected1, + ) + + verify_model( + Attention2(), + ( + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 128, dtype=torch.float32), + ), + {}, + Expected2, + ) + + +def test_unbind(): + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((0, 3, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] + lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] + lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv8, lv9, lv10) + R.output(gv) + return gv + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 0, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) + lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] + lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] + lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv8, lv9, lv10) + R.output(gv) + return gv + + example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) + verify_model(Unbind1(), example_args, {}, expected1) + verify_model(Unbind2(), example_args, {}, expected2) + + +def test_interpolate(): + class InterpolateBilinear(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (224, 224), mode="bilinear") + + @tvm.script.ir_module + class expected_bilinear: + @R.function + def main( + input: R.Tensor((1, 3, 112, 112), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( + input, + R.shape([224, 224]), + roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], + layout="NCHW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0.0, + out_dtype="void", + ) + gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class InterpolateNearest(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (224, 224), mode="nearest") + + @tvm.script.ir_module + class expected_nearest: + @R.function + def main( + input: R.Tensor((1, 3, 112, 112), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( + input, + R.shape([224, 224]), + roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0.0, + out_dtype="void", + ) + gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),) + verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear) + verify_model(InterpolateNearest(), example_args, {}, expected_nearest) + + def test_mean(): class Mean(Module): def forward(self, input): From e80801030ebafa38195666962d3fb79b2e433616 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 30 Sep 2024 18:36:41 +0530 Subject: [PATCH 181/202] [DLIGHT][GPU] Improve matmul schedule for adreno (#17430) Improved matmul schedule with layout transpose approach, which improves as follows - ----Model-------prefill baseline ---------prefill optimized --Llama-2-7b-------51 tok/sec --------------86 tok/sec --Llama-3-8b-------48 tok/sec --------------79 tok/sec --gemma-2b -------140 tok/sec -------------245 tok/sec --------- --- python/tvm/dlight/gpu/matmul.py | 108 ++++++++------ tests/python/dlight/test_gpu_matmul.py | 196 +++++++++++++++---------- 2 files changed, 178 insertions(+), 126 deletions(-) diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 5568083982b9..d9d4b7ebd4d2 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -26,6 +26,7 @@ from tvm.tir import IterVar, PrimExpr, Var from tvm.tir.analysis import undefined_vars from tvm.tir.schedule.schedule import BlockRV +from tvm.script import tir as T from ..base import analysis, BlockInfo, IterInfo from .base import GPUScheduleRule @@ -945,14 +946,14 @@ def get_configs(self, target: Target) -> Config: ): return Matmul.Config( block_size_x=32, - block_size_y=8, + block_size_y=4, vthread_x=1, vthread_y=1, micro_size_x=8, micro_size_y=2, micro_size_k=16, vector_size=8, - unroll=4, + unroll=16, use_shared=False, storage_align=False, inner_x=True, @@ -1147,7 +1148,7 @@ def get_max_factor(n, factors): if not ( isinstance(sch.get(n).extent, tir.IntImm) and isinstance(sch.get(mb).extent, tir.IntImm) - and isinstance(sch.get(ms).extent, tir.Var) + and not isinstance(sch.get(ms).extent, tir.IntImm) ): return None @@ -1157,6 +1158,7 @@ def get_max_factor(n, factors): config.vector_size, config.unroll, ) + VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize) dequant_block = None matmul_block = reduction_block @@ -1169,61 +1171,73 @@ def get_max_factor(n, factors): elif blk is not matmul_block: sch.compute_inline(blk) - m = sch.fuse(mb, ms) - - sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * VecSize, 1]) - - rmat_block, wmat_block = ( + block = sch.reindex(reduction_block, ("read", 0)) + sch.pad_einsum(reduction_block, [1, Unroll_M, 1, 1]) + sch.compute_inline(block) + trans_block, matmul_reindex = ( sch.get_producers(matmul_block)[0], sch.get_consumers(matmul_block)[0], ) - mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M]) - no, ni, nv = sch.split(n, [None, Threads_X, VecSize]) - k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 8]) - sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv) - sch.compute_at(rmat_block, k0) - if dequant_block is not None: - sch.compute_at(dequant_block, k3) - sch.reverse_compute_at(wmat_block, mi) - sch.set_scope(rmat_block, 0, "shared") - sch.set_scope(matmul_block, 0, "local") + if epilogue_block is not None: + sch.compute_inline(matmul_reindex) + matmul_reindex = epilogue_block - if dequant_block is not None: - sch.set_scope(dequant_block, 0, "local") + sch.transform_layout( + trans_block, + ("write", 0), + T.index_map(lambda i0, i1, i2: (i0, i1 // Unroll_M, i2, i1 % Unroll_M)), + ) - sch.bind(mo, "blockIdx.y") - sch.bind(no, "blockIdx.x") - sch.bind(mi, "threadIdx.y") - sch.bind(ni, "threadIdx.x") - sch.vectorize(sch.get_loops(matmul_block)[-1]) + # transpose block schedules + # sch.set_scope(trans_block, 0, "global.texture-1d") + tb, tn, tk = sch.get_loops(trans_block) + tbx, ttx = sch.split(tk, [None, Threads_X]) + tby, tty, tc = sch.split(tn, [None, Threads_Y, Unroll_M]) + sch.bind(tb, "blockIdx.z") + sch.bind(tby, "blockIdx.y") + sch.bind(tbx, "blockIdx.x") + sch.bind(tty, "threadIdx.y") + sch.bind(ttx, "threadIdx.x") + sch.reorder(tb, tby, tbx, tty, ttx, tc) + sch.vectorize(tc) + + mb, ms, n, k = sch.get_loops(matmul_block) + m = sch.fuse(mb, ms) + bx, tx, vec = sch.split(n, [None, Threads_X, VecSize]) + by, ty, unr = sch.split(m, [None, Threads_Y, Unroll_M]) + k1, k2, k3 = sch.split(k, [None, 4, 8]) + sch.reorder(bx, by, tx, ty, k1, k2, k3, unr, vec) + sch.set_scope(matmul_block, 0, "local") if dequant_block is not None: - sch.vectorize(sch.get_loops(dequant_block)[-1]) + sch.compute_at(dequant_block, k3) + sch.set_scope(dequant_block, 0, "local") + sch.bind(by, "blockIdx.y") + sch.bind(bx, "blockIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) - # Co-operative Memory Fetch - ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize]) - sch.bind(ro, "threadIdx.x") - sch.vectorize(rv) + inp = sch.cache_read(matmul_block, read_buffer_index=0, storage_scope="local") + sch.compute_at(inp, k3, preserve_unit_loops=True) + sch.vectorize(sch.get_loops(inp)[-1]) - wv = sch.get_loops(wmat_block)[-1] - sch.vectorize(wv) + sch.unroll(unr) + sch.unroll(k3) - # Scale and Quant Cache if dequant_block is not None: - qb = sch.cache_read(dequant_block, 0, "local") - sb = sch.cache_read(dequant_block, 1, "local") - sch.compute_at(sb, k1) - sch.compute_at(qb, k2) - sch.set_scope(sb, 0, "local") - sch.set_scope(qb, 0, "local") - sch.vectorize(sch.get_loops(qb)[-1]) - sch.vectorize(sch.get_loops(sb)[-1]) + Aq_local = sch.cache_read(dequant_block, read_buffer_index=0, storage_scope="local") + sch.compute_at(Aq_local, k2, preserve_unit_loops=True) + sch.vectorize(sch.get_loops(Aq_local)[-1]) + As_local = sch.cache_read(dequant_block, read_buffer_index=1, storage_scope="local") + sch.compute_at(As_local, k1, preserve_unit_loops=True) + sch.vectorize(sch.get_loops(As_local)[-1]) + sch.vectorize(sch.get_loops(dequant_block)[-1]) - if epilogue_block is not None: - sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True) - sch.set_scope(wmat_block, 0, "local") - sch.compute_inline(wmat_block) - sch.vectorize(sch.get_loops(epilogue_block)[-1]) + sch.reverse_compute_at(matmul_reindex, ty) + o_ur, o_vec = sch.get_loops(matmul_reindex)[-2:] + sch.vectorize(o_vec) + sch.unroll(o_ur) + sch.decompose_reduction(matmul_block, k1) - sch.decompose_reduction(matmul_block, k0) return sch diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index dc5276e62a5f..83b52efc3a69 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -634,49 +634,68 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) # with T.block("root"): - inp0_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared") - matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local") + inp0_reindex_pad = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16))) + matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local") + inp0_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), scope="local") + for i0 in T.thread_binding(T.int64(1), thread="blockIdx.z"): + for i1_0 in T.thread_binding(((m + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"): + for i2_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"): + for i1_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i1_2 in T.vectorized(T.int64(16)): + with T.block("inp0_reindex_pad"): + v0 = T.axis.spatial(T.int64(1), i0) + v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i1_0 * T.int64(64) + i1_1 * T.int64(16) + i1_2) + v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(32) + i2_1) + T.where((i1_0 * T.int64(4) + i1_1) * T.int64(16) + i1_2 < (m + T.int64(15)) // T.int64(16) * T.int64(16)) + T.reads(inp0[v0, v1, v2]) + T.writes(inp0_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)]) + inp0_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0)) for i2_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"): - for i0_i1_fused_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + for i0_i1_fused_0 in T.thread_binding(((m + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"): for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i0_i1_fused_2_init in range(T.int64(4)): + for i0_i1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for i0_i1_fused_2_init in T.unroll(T.int64(16)): for i2_2_init in T.vectorized(T.int64(8)): with T.block("matmul_init"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2_init) + v_i1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2_init) v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2_init < (m + T.int64(15)) // T.int64(16) * T.int64(16)) T.reads() T.writes(matmul_pad_local[v_i0, v_i1, v_i2]) matmul_pad_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0 in range(T.int64(16)): - for ax0 in range(T.int64(4)): - for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax1_1 in T.vectorized(T.int64(8)): - with T.block("inp0_pad"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) - v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1) - T.reads(inp0[v0, v1, v2]) - T.writes(inp0_pad_shared[v0, v1, v2]) - inp0_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0)) - for k_1, k_2, k_3, i0_i1_fused_2 in T.grid(T.int64(8), T.int64(4), T.int64(8), T.int64(4)): - for i2_2 in T.vectorized(T.int64(8)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2) - v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) - T.reads(matmul_pad_local[v_i0, v_i1, v_i2], inp0_pad_shared[v_i0, v_i1, v_k], inp1[v_k, v_i2]) - T.writes(matmul_pad_local[v_i0, v_i1, v_i2]) - matmul_pad_local[v_i0, v_i1, v_i2] = matmul_pad_local[v_i0, v_i1, v_i2] + inp0_pad_shared[v_i0, v_i1, v_k] * inp1[v_k, v_i2] - for ax0 in range(T.int64(4)): + for k_0, k_1 in T.grid(T.int64(128), T.int64(4)): + for k_2 in T.unroll(T.int64(8)): + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + for ax3 in T.vectorized(T.int64(16)): + with T.block("inp0_reindex_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16), i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 + ax1) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2) + v3 = T.axis.spatial(T.int64(16), ax3) + T.where(i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 < (m + T.int64(15)) // T.int64(16)) + T.reads(inp0_reindex_pad[v0, v1, v2, v3]) + T.writes(inp0_reindex_pad_local[v0, v1, v2, v3]) + inp0_reindex_pad_local[v0, v1, v2, v3] = inp0_reindex_pad[v0, v1, v2, v3] + for i0_i1_fused_2 in T.unroll(T.int64(16)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2) + v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2 < (m + T.int64(15)) // T.int64(16) * T.int64(16)) + T.reads(matmul_pad_local[v_i0, v_i1, v_i2], inp0_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)], inp1[v_k, v_i2]) + T.writes(matmul_pad_local[v_i0, v_i1, v_i2]) + matmul_pad_local[v_i0, v_i1, v_i2] = matmul_pad_local[v_i0, v_i1, v_i2] + inp0_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)] * inp1[v_k, v_i2] + for ax0 in T.unroll(T.int64(16)): for ax1 in T.vectorized(T.int64(8)): with T.block("matmul_pad"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) + v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0) v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) - T.where((i0_i1_fused_0 - (m + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < m) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m) T.reads(matmul_pad_local[v0, v1, v2]) T.writes(matmul[v0, v1, v2]) matmul[v0, v1, v2] = matmul_pad_local[v0, v1, v2] @@ -729,75 +748,94 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") # with T.block("root"): dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local") - rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared") - matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", scope="local") + rms_norm130_reindex_pad = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), "float16") + matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(12288)), "float16", scope="local") + rms_norm130_reindex_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), "float16", scope="local") lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") + for i0 in T.thread_binding(T.int64(1), thread="blockIdx.z"): + for i1_0 in T.thread_binding(((seq_len + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"): + for i2_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"): + for i1_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i1_2 in T.vectorized(T.int64(16)): + with T.block("rms_norm130_reindex_pad"): + v0 = T.axis.spatial(T.int64(1), i0) + v1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i1_0 * T.int64(64) + i1_1 * T.int64(16) + i1_2) + v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(32) + i2_1) + T.where((i1_0 * T.int64(4) + i1_1) * T.int64(16) + i1_2 < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16)) + T.reads(rms_norm130[v0, v1, v2]) + T.writes(rms_norm130_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)]) + rms_norm130_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0)) for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"): - for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + for i0_i1_fused_0 in T.thread_binding(((seq_len + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"): for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i0_i1_fused_2_init in range(T.int64(4)): + for i0_i1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for i0_i1_fused_2_init in T.unroll(T.int64(16)): for i2_2_init in T.vectorized(T.int64(8)): with T.block("matmul_init"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2_init) + v_i1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2_init) v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2_init < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16)) T.reads() T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(16)): - for ax0 in range(T.int64(4)): - for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax1_1 in T.vectorized(T.int64(8)): - with T.block("rms_norm130_pad"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) - v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1) - T.reads(rms_norm130[v0, v1, v2]) - T.writes(rms_norm130_pad_shared[v0, v1, v2]) - rms_norm130_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0)) - for k_1 in range(T.int64(8)): - for ax0 in T.vectorized(T.int64(8)): + for k_0 in range(T.int64(128)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): with T.block("lv453_local"): - v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1) - v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + v0 = T.axis.spatial(T.int64(128), k_0 + ax0) + v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) T.reads(lv453[v0, v1]) T.writes(lv453_local[v0, v1]) lv453_local[v0, v1] = lv453[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in T.vectorized(T.int64(8)): + for k_1 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): with T.block("lv452_local"): - v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) - v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(4) + k_1 + ax0) + v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) T.reads(lv452[v0, v1]) T.writes(lv452_local[v0, v1]) lv452_local[v0, v1] = lv452[v0, v1] - for k_3 in range(T.int64(8)): - for ax0 in T.vectorized(T.int64(8)): - with T.block("dequantize"): - v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) - v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) - T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1]) - T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1]) - dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1] - for i0_i1_fused_2 in range(T.int64(4)): - for i2_2 in T.vectorized(T.int64(8)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2) - v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) - T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2]) - T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2] - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(8)): + for k_2 in T.unroll(T.int64(8)): + for ax0 in T.vectorized(T.int64(8)): + with T.block("dequantize"): + v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2) + v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1]) + T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1]) + dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + for ax3 in T.vectorized(T.int64(16)): + with T.block("rms_norm130_reindex_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16), i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 + ax1) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2) + v3 = T.axis.spatial(T.int64(16), ax3) + T.where(i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 < (seq_len + T.int64(15)) // T.int64(16)) + T.reads(rms_norm130_reindex_pad[v0, v1, v2, v3]) + T.writes(rms_norm130_reindex_pad_local[v0, v1, v2, v3]) + rms_norm130_reindex_pad_local[v0, v1, v2, v3] = rms_norm130_reindex_pad[v0, v1, v2, v3] + for i0_i1_fused_2 in T.unroll(T.int64(16)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2) + v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2 < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16)) + T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)], dequantize_intermediate_intermediate_local[v_k, v_i2]) + T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)] * dequantize_intermediate_intermediate_local[v_k, v_i2] + for ax0 in T.unroll(T.int64(16)): + for ax1 in T.vectorized(T.int64(8)): with T.block("T_add"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1) - v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2) - T.where(i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1 < seq_len) + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0) + v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len) T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2]) T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2] From 7569148c3c5fbf3a9f4e65f80488434b6c4bcb84 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 30 Sep 2024 21:07:01 +0800 Subject: [PATCH 182/202] [Relax] Introduce static shape tuning pipeline (#17428) This PR introduces a static shape tuning pipeline for Relax. It is designed to work with the MetaSchedule tuning framework to optimize the performance of the model. Together with a minor typo fix --- docs/how_to/tutorials/e2e_opt_model.py | 16 +--------- python/tvm/relax/pipeline.py | 39 +++++++++++++++++++++++++ python/tvm/relax/transform/transform.py | 5 ++-- 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/docs/how_to/tutorials/e2e_opt_model.py b/docs/how_to/tutorials/e2e_opt_model.py index 0053d309d5a9..5c11439e1635 100644 --- a/docs/how_to/tutorials/e2e_opt_model.py +++ b/docs/how_to/tutorials/e2e_opt_model.py @@ -101,21 +101,7 @@ # Skip running in CI environment IS_IN_CI = os.getenv("CI", "") == "true" if not IS_IN_CI: - with target: - mod = tvm.ir.transform.Sequential( - [ - # Convert BatchNorm into a sequence of simpler ops for fusion - relax.transform.DecomposeOpsForInference(), - # Canonicalize the bindings - relax.transform.CanonicalizeBindings(), - # Run default optimization pipeline - relax.get_pipeline("zero"), - # Tune the model and store the log to database - relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS), - # Apply the database - relax.transform.MetaScheduleApplyDatabase(work_dir), - ] - )(mod) + mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod) # Only show the main function mod["main"].show() diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 38242ff4d2d3..582f5111aaf5 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -21,6 +21,7 @@ as it is or serves as a basis to do further composition. """ # pylint: disable=unused-argument +from typing import Union import tvm from tvm import meta_schedule as ms @@ -104,10 +105,48 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I return _pipeline +def static_shape_tuning_pipeline( + total_trials: int, + target: Union[str, tvm.target.Target], + work_dir: str = "tuning_logs", +): + """Tune the static shape model and store the log to database. + + Parameters + ---------- + total_trials : int + Total number of trials to run. + + target : Union[str, tvm.target.Target] + The target device to tune the model. + + work_dir : str + The directory to store the tuning logs. + """ + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + with tvm.target.Target(target): + mod = tvm.transform.Sequential( + [ + transform.DecomposeOpsForInference(), + transform.CanonicalizeBindings(), + zero_pipeline(), + transform.MetaScheduleTuneIRMod({}, work_dir, total_trials), + transform.MetaScheduleApplyDatabase(work_dir), + ] + )(mod) + + return mod + + return _pipeline + + # global map of pre-built pipelines PIPELINE_MAP = { "zero": zero_pipeline, "default_build": default_build_pipeline, + "static_shape_tuning": static_shape_tuning_pipeline, } diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 95649f331f33..3330d4098734 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1020,14 +1020,13 @@ def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transfor ---------- param_tuple_name: Optional[str] - The name of the tuple parameter. If unspecified, defaults to + The name of the tuple parameter. If unspecified, defaults to "model_params". Returns ------- ret : tvm.transform.Pass - The registered pass for lifting transformation of parameters. - + The registered pass for bundling model parameters. """ return _ffi_api.BundleModelParams(param_tuple_name) # type: ignore From 4f948901124761ce27dba4f0e4b752480315893c Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 30 Sep 2024 08:47:36 -0700 Subject: [PATCH 183/202] [NVSHMEM] Enable nvshmem memory allocation (#17415) This PR add the support of nvshmem memory allocation, and integrates it into disco. --- .../contrib/nvshmem/{nvshmem.cc => init.cc} | 2 + .../contrib/nvshmem/memory_allocator.cc | 104 ++++++++++++++++++ tests/python/disco/test_nvshmem.py | 45 +++++++- 3 files changed, 145 insertions(+), 6 deletions(-) rename src/runtime/contrib/nvshmem/{nvshmem.cc => init.cc} (96%) create mode 100644 src/runtime/contrib/nvshmem/memory_allocator.cc diff --git a/src/runtime/contrib/nvshmem/nvshmem.cc b/src/runtime/contrib/nvshmem/init.cc similarity index 96% rename from src/runtime/contrib/nvshmem/nvshmem.cc rename to src/runtime/contrib/nvshmem/init.cc index 985ba5510762..50fdde4c49d8 100644 --- a/src/runtime/contrib/nvshmem/nvshmem.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -54,6 +54,8 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) { } nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr); nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr); + int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE); + CUDA_CALL(cudaSetDevice(mype_node)); LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " " << ", npes=" << nvshmem_n_pes(); } diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc new file mode 100644 index 000000000000..89d56ed3dc81 --- /dev/null +++ b/src/runtime/contrib/nvshmem/memory_allocator.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include + +#include "../../cuda/cuda_common.h" +#include "../../memory/pooled_allocator.h" + +namespace tvm { +namespace runtime { + +using tvm::runtime::memory::Buffer; +using tvm::runtime::memory::PooledAllocator; + +/*! + * \brief The memory allocator of NVSHMEM. + * Overriding PooledAllocator for efficient memory management. + */ +class NVSHMEMAllocator final : public PooledAllocator { + public: + explicit NVSHMEMAllocator() : PooledAllocator() {} + + ~NVSHMEMAllocator() { PooledAllocator::ReleaseAll(); } + + void Clear() final { PooledAllocator::ReleaseAll(); } + + bool AllowMemoryScope(const std::string& mem_scope) const final { + // The allowed memory scope of NVSHMEM is "nvshmem"; + return mem_scope == "nvshmem"; + } + + /*! \brief Return the global NVSHMEM singleton allocator. */ + static NVSHMEMAllocator* Global() { + static NVSHMEMAllocator* allocator = new NVSHMEMAllocator(); + return allocator; + } + + NDArray Empty(ShapeTuple shape, DataType dtype, Device device) { + NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, device); + container->SetDeleter([](Object* obj) { + auto* ptr = static_cast(obj); + ICHECK(ptr->manager_ctx != nullptr); + Buffer* buffer = reinterpret_cast(ptr->manager_ctx); + NVSHMEMAllocator::Global()->Free(*(buffer)); + delete buffer; + delete ptr; + }); + Buffer* buffer = new Buffer; + *buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem")); + container->manager_ctx = reinterpret_cast(buffer); + container->dl_tensor.data = buffer->data; + return NDArray(GetObjectPtr(container)); + } + + private: + void* DeviceAllocDataSpace(Device dev, size_t size, size_t alignment, + DLDataType type_hint) final { + ICHECK_EQ(dev.device_type, DLDeviceType::kDLCUDA) + << "nvshmem can only allocate cuda device memory space."; + ICHECK(type_hint.code == DLDataTypeCode::kDLInt || type_hint.code == DLDataTypeCode::kDLUInt || + type_hint.code == DLDataTypeCode::kDLFloat) + << "nvshmem can only allocate tensor with int, usingned int or float data types."; + return nvshmem_align(alignment, size); + } + + void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); } +}; + +NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) { + return NVSHMEMAllocator::Global()->Empty(shape, dtype, device); +} + +TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty); + +void NVSHMEMFinalize() { + NVSHMEMAllocator::Global()->Clear(); + nvshmem_finalize(); +} + +TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize); + +} // namespace runtime +} // namespace tvm diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py index 0b16fe93612f..b304d145aa38 100644 --- a/tests/python/disco/test_nvshmem.py +++ b/tests/python/disco/test_nvshmem.py @@ -23,6 +23,9 @@ import subprocess import threading import sys +from multiprocessing import Process +from typing import Any, Callable, List + import tvm import tvm.testing @@ -82,8 +85,6 @@ def start_server(): thread.join() def __del__(self): - for node in self.remote_nodes: - node.kill() if self.sess is not None: self.sess.shutdown() del self.sess @@ -98,17 +99,49 @@ def create_socket_session(num_workers): return _SOCKET_SESSION_TESTER.sess -@pytest.mark.parametrize("num_workers", [2, 4]) -def test_nvshmem_init(num_workers): +def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int): if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None: return - sess = create_socket_session(num_workers=num_workers) + + sess = session_kind(num_workers=num_workers) f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid") uid = f_init_nvshmem_uid() init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem") init_dfunc(uid, num_workers) sess.sync_worker_0() + finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem") + finalize_dfunc() + sess.sync_worker_0() + + +def test_nvshmem_empty(session_kind: di.Session, num_workers: int): + if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None: + return + + device = tvm.cuda() + sess = session_kind(num_workers=num_workers) + f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid") + uid = f_init_nvshmem_uid() + init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem") + init_dfunc(uid, num_workers) + sess.sync_worker_0() + empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty") + a = empty_dfunc(ShapeTuple((32, 64)), "float32", device) + b = empty_dfunc(ShapeTuple((64, 32)), "float32", device) + sess.sync_worker_0() + finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem") + finalize_dfunc() + sess.sync_worker_0() if __name__ == "__main__": - tvm.testing.main() + # After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init` + # or `nvshmem_init_thread` in the same program results in undefined behavior. + # So we always create a new process to run the test. Then no repeated nvshmem + # init happens in the same process, since the worker0 may share the same process. + for session_kind in [create_socket_session, di.ProcessSession]: + for num_workers in [2, 4]: + for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]: + p = Process(target=test_func, args=[session_kind, num_workers]) + p.start() + p.join() From fab67a9af918607542d8e6a895d53cc28030d7bd Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 2 Oct 2024 09:33:01 +0900 Subject: [PATCH 184/202] [Relax][PyTorch] Support tensor manipulation and creation ops for ExportedProgram importer (#17429) * support cat and concat * support cumsum * support expand * support permute * support squeeze * support tile * support transpose * support unsqueeze * add test for flatten * support repeat * add test for reshape * support select and slice * support arange * support empty * support fill * support new_ones * support _to_copy * support split * add test for unbind * support clone --- .../torch/base_fx_graph_translator.py | 161 ++++ .../torch/exported_program_translator.py | 39 + .../tvm/relax/frontend/torch/fx_translator.py | 139 ---- .../test_frontend_from_exported_program.py | 781 ++++++++++++++++++ 4 files changed, 981 insertions(+), 139 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 52784dc8c3cd..322ee04e0c20 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -730,6 +730,51 @@ def convert(node: fx.Node): ########## Manipulation ########## + def _cat(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + + def _cumsum(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + if "dtype" in node.kwargs: + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + else: + dtype = None + if "out" in node.kwargs: + raise ValueError("specifying out for cumsum is not supported yet") + + return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) + + def _expand(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + sizes = args[1:] if len(args) > 2 else args[1] + broadcast_shape, in_shape = [], self.shape_of(args[0]) + for idx, i in enumerate(sizes): + if isinstance(i, int) and i == -1: + broadcast_shape.append(in_shape[idx]) + else: + broadcast_shape.append(i) + return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) + + def _permute(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + x = args[0] + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.permute_dims(x, dims)) + + def _repeat(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + x = args[0] + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.tile(x, dims)) + def _reshape(self, node: fx.Node) -> relax.Var: import torch # type: ignore @@ -738,6 +783,122 @@ def _reshape(self, node: fx.Node) -> relax.Var: dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.reshape(x, dims)) + def _split(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + split_size = node.args[1] + dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) + if isinstance(split_size, (list, tuple)): + n_section = [] + for s in split_size[:-1]: + cum_sum = 0 if not n_section else n_section[-1] + n_section.append(s + cum_sum) + else: + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + + def _squeeze(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + return self.block_builder.emit(relax.op.squeeze(x, dim)) + + def _tile(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + x = args[0] + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.tile(x, dims)) + + def _transpose(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + full_idx = list(range(len(self.shape_of(args[0])))) + full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] + return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) + + ########## Creation ########## + + def _to_copy(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = self._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = self._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x + + def _arange(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + start_end_step = [None, None, None] + if "start" in node.kwargs: + start_end_step[0] = node.kwargs["start"] + if "end" in node.kwargs: + start_end_step[1] = node.kwargs["end"] + if "step" in node.kwargs: + start_end_step[2] = node.kwargs["step"] + + if len(node.args) == 1: + assert start_end_step[1] is None + start_end_step[1] = node.args[0] + elif len(node.args) == 2: + assert start_end_step[0] is None + assert start_end_step[1] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + elif len(node.args) == 3: + assert start_end_step[0] is None + assert start_end_step[1] is None + assert start_end_step[2] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + start_end_step[2] = node.args[2] + + if start_end_step[0] is None: + start_end_step[0] = 0 + if start_end_step[2] is None: + start_end_step[2] = 1 + + if "dtype" in node.kwargs: + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + elif any([isinstance(x, float) for x in start_end_step]): + dtype = self._convert_data_type(torch.get_default_dtype()) + else: + dtype = "int64" + start_end_step = [ + self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step + ] + return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) + + def _empty(self, node: fx.Node) -> relax.Var: + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + return self.block_builder.emit(relax.op.zeros(node.args[0], dtype)) + + def _fill(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) + return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + + def _new_ones(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1] if isinstance(args[1], (list, tuple)) else args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) + ########## Others ########## def _getitem(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 64583d750974..1401a0bcef3a 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -162,6 +162,22 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None) return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor") + ########## Manipulation ########## + + def _select(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = relax.const(node.args[2], "int64") + return self.block_builder.emit(relax.op.take(x, index, dim)) + + def _slice(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + axes = [node.args[1]] + begin = [node.args[2]] + end = [node.args[3]] + stride = [node.args[4] if len(node.args) > 4 else 1] + return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -249,7 +265,30 @@ def create_convert_map( "argmax.default": self._argmax_argmin(relax.op.argmax), "argmin.default": self._argmax_argmin(relax.op.argmin), # tensor manipulation + "cat.default": self._cat, + "concat.default": self._cat, + "cumsum.default": self._cumsum, + "expand.default": self._expand, + "permute.default": self._permute, + "repeat.default": self._repeat, + "select.int": self._select, + "slice.Tensor": self._slice, + "split.Tensor": self._split, + "squeeze.default": self._squeeze, + "squeeze.dim": self._squeeze, + "tile.default": self._tile, + "transpose.int": self._transpose, + "unsqueeze.default": lambda node: self.block_builder.emit( + relax.op.expand_dims(self.env[node.args[0]], node.args[1]) + ), "view.default": self._reshape, + # tensor creation + "_to_copy.default": self._to_copy, + "arange.start": self._arange, + "clone.default": lambda node: self.env[node.args[0]], + "empty.memory_format": self._empty, + "fill.Scalar": self._fill, + "new_ones.default": self._new_ones, # other "getitem": self._getitem, } diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index c60c7c3953b4..9fbc95fa7c00 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -380,41 +380,12 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: ########## Manipulation ########## - def _cat(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) - def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] chunks = node.args[1] dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.split(x, chunks, dim)) - def _cumsum(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) - if "dtype" in node.kwargs: - dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) - else: - dtype = None - if "out" in node.kwargs: - raise ValueError("specifying out for cumsum is not supported yet") - - return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) - - def _expand(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - sizes = args[1:] if len(args) > 2 else args[1] - broadcast_shape, in_shape = [], self.shape_of(args[0]) - for idx, i in enumerate(sizes): - if isinstance(i, int) and i == -1: - broadcast_shape.append(in_shape[idx]) - else: - broadcast_shape.append(i) - return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) - def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var: shape = self.shape_of(x) start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim @@ -440,22 +411,6 @@ def _flatten_module(self, node: fx.Node) -> relax.Var: end_dim = module.end_dim return self._flatten_impl(x, start_dim, end_dim) - def _permute(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - x = args[0] - dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.permute_dims(x, dims)) - - def _repeat(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - x = args[0] - dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.tile(x, dims)) - def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -466,87 +421,8 @@ def _size(self, node: fx.Node) -> relax.Expr: idx = node.args[1] return self.shape_of(x)[idx].value - def _split(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - split_size = node.args[1] - dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) - if isinstance(split_size, (list, tuple)): - n_section = [] - for s in split_size[:-1]: - cum_sum = 0 if not n_section else n_section[-1] - n_section.append(s + cum_sum) - else: - n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size - return self.block_builder.emit(relax.op.split(x, n_section, dim)) - - def _squeeze(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) - return self.block_builder.emit(relax.op.squeeze(x, dim)) - - def _tile(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - x = args[0] - dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.tile(x, dims)) - - def _transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - full_idx = list(range(len(self.shape_of(args[0])))) - full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] - return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - ########## Creation ########## - def _arange(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - start_end_step = [None, None, None] - if "start" in node.kwargs: - start_end_step[0] = node.kwargs["start"] - if "end" in node.kwargs: - start_end_step[1] = node.kwargs["end"] - if "step" in node.kwargs: - start_end_step[2] = node.kwargs["step"] - - if len(node.args) == 1: - assert start_end_step[1] is None - start_end_step[1] = node.args[0] - elif len(node.args) == 2: - assert start_end_step[0] is None - assert start_end_step[1] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - elif len(node.args) == 3: - assert start_end_step[0] is None - assert start_end_step[1] is None - assert start_end_step[2] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - start_end_step[2] = node.args[2] - - if start_end_step[0] is None: - start_end_step[0] = 0 - if start_end_step[2] is None: - start_end_step[2] = 1 - - if "dtype" in node.kwargs: - dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) - elif any([isinstance(x, float) for x in start_end_step]): - dtype = self._convert_data_type(torch.get_default_dtype()) - else: - dtype = "int64" - start_end_step = [ - self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step - ] - return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) - - def _empty(self, node: fx.Node) -> relax.Var: - dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) - return self.block_builder.emit(relax.op.zeros(node.args[0], dtype)) - def _inplace_fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] @@ -596,21 +472,6 @@ def _masked_fill(self, node: fx.Node) -> relax.Var: values = self.block_builder.emit(relax.op.full_like(x, rx_value)) return self.block_builder.emit(relax.op.where(mask, values, x)) - def _new_ones(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - self_var = args[0] - size = args[1] if isinstance(args[1], (list, tuple)) else args[1:] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, self_var.struct_info.dtype), - self_var.struct_info.dtype, - ) - ) - def _ones(self, node: fx.Node) -> relax.Var: import torch diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 7c887d9b9610..65890ff6971b 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2734,6 +2734,582 @@ def main( verify_model(Argmin2(), example_args, {}, expected_argmin2) +def test_cat_concat(): + class Cat0(Module): + def forward(self, x, y): + return torch.cat((x, y)) + + class Cat1(Module): + def forward(self, x, y): + return torch.cat((x, y), dim=1) + + class Cat2(Module): + def forward(self, x, y): + return torch.cat((x, y), 1) + + class Cat3(Module): + def forward(self, x, y): + return torch.concat((x, y), dim=0) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + inp_1: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0, inp_1), axis=0) + gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + inp_1: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 6), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0, inp_1), axis=1) + gv: R.Tuple(R.Tensor((2, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32)) + verify_model(Cat0(), example_args, {}, Expected1) + verify_model(Cat1(), example_args, {}, Expected2) + verify_model(Cat2(), example_args, {}, Expected2) + verify_model(Cat3(), example_args, {}, Expected1) + + +def test_cumsum(): + class Cumsum(Module): + def forward(self, input): + return torch.cumsum(input, dim=1, dtype=torch.int32) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Cumsum(), example_args, {}, expected1) + + +def test_expand(): + class Expand1(Module): + def forward(self, x): + return x.expand(4, 2, 3, 4) + + class Expand2(Module): + def forward(self, x): + return x.expand(4, -1, -1, 4) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4)) + gv: R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Expand1(), example_args, {}, expected1) + verify_model(Expand2(), example_args, {}, expected1) + + +def test_flatten(): + class Flatten(Module): + def __init__(self): + super().__init__() + self.f = torch.nn.Flatten(2, -1) + + def forward(self, input): + return self.f(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100)) + gv: R.Tuple(R.Tensor((1, 3, 100), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Flatten(), example_args, {}, expected1) + + +def test_permute(): + class Permute1(Module): + def forward(self, x): + return x.permute(0, 3, 2, 1) + + class Permute2(Module): + def forward(self, x): + return torch.permute(x, (0, 3, 2, 1)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Permute1(), example_args, {}, expected1) + verify_model(Permute2(), example_args, {}, expected1) + + +def test_repeat(): + class Tile1(Module): + def forward(self, x: torch.Tensor): + return x.repeat(2) + + class Tile2(Module): + def forward(self, x: torch.Tensor): + return x.repeat(4, 2) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2) + gv: R.Tuple(R.Tensor((6,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((1, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) + gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(3, dtype=torch.float32),) + verify_model(Tile1(), example_args, {}, expected1) + + example_args = (torch.randn(1, 3, dtype=torch.float32),) + verify_model(Tile2(), example_args, {}, expected2) + + example_args = (torch.randn(1, 3, dtype=torch.float32),) + verify_model(Tile2(), example_args, {}, expected2) + + +def test_reshape(): + class Reshape(Module): + def forward(self, x): + return x.reshape(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Reshape(), example_args, {}, expected1) + + +def test_select_slice(): + class Slice1(Module): + def forward(self, x): + return x[0, 1::2, :, :3] + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((3, 10, 10), dtype="float32") = R.take(x, R.const(0, "int64"), axis=0) + lv1: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice( + lv, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(9223372036854775807),), + (R.prim_value(2),), + assume_inbound=False, + ) + lv2: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice( + lv1, + (R.prim_value(1),), + (R.prim_value(0),), + (R.prim_value(9223372036854775807),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((1, 10, 3), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(2),), + (R.prim_value(0),), + (R.prim_value(3),), + (R.prim_value(1),), + assume_inbound=False, + ) + gv: R.Tuple(R.Tensor((1, 10, 3), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Slice2(Module): + def forward(self, x): + return x[:, None, None, :, None] + + @I.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((8, 16), dtype="float32") + ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( + x, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(9223372036854775807),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((8, 1, 16), dtype="float32") = R.expand_dims(lv, axis=[1]) + lv2: R.Tensor((8, 1, 1, 16), dtype="float32") = R.expand_dims(lv1, axis=[2]) + lv3: R.Tensor((8, 1, 1, 16), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(3),), + (R.prim_value(0),), + (R.prim_value(9223372036854775807),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv4: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.expand_dims(lv3, axis=[4]) + gv: R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Slice1(), example_args, {}, expected1) + + example_args = (torch.randn(8, 16, dtype=torch.float32),) + verify_model(Slice2(), example_args, {}, expected2) + + +def test_split(): + class Chunk(Module): + def forward(self, input): + return torch.chunk(input, 3, dim=1) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=3, axis=1) + lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1] + lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2] + gv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = (lv1, lv2, lv3) + R.output(gv) + return gv + + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((0, 3, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] + lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] + lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv8, lv9, lv10) + R.output(gv) + return gv + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 0, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) + lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] + lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] + lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv8, lv9, lv10) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Chunk(), example_args, {}, Expected) + + example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) + verify_model(Unbind1(), example_args, {}, expected1) + verify_model(Unbind2(), example_args, {}, expected2) + + +def test_squeeze(): + class Squeeze1(Module): + def forward(self, input): + return input.squeeze(1) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1]) + gv: R.Tuple(R.Tensor((3, 4, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Squeeze2(Module): + def forward(self, input): + return input.squeeze() + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) + gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),) + + verify_model(Squeeze1(), example_args, {}, Expected1) + verify_model(Squeeze2(), example_args, {}, Expected2) + + +def test_tile(): + class Tile1(Module): + def forward(self, x): + return x.tile((2,)) + + class Tile2(Module): + def forward(self, x): + return x.tile(4, 2) + + class Tile3(Module): + def forward(self, x): + return torch.tile(x, (4, 2)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2]) + gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((1, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) + gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, dtype=torch.float32),) + verify_model(Tile1(), example_args, {}, expected1) + verify_model(Tile2(), example_args, {}, expected2) + verify_model(Tile3(), example_args, {}, expected2) + + +def test_transpose(): + class Transpose(Module): + def forward(self, x): + return x.transpose(1, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Transpose(), example_args, {}, expected1) + + +def test_unsqueeze(): + class Unsqueeze1(Module): + def forward(self, input): + return input.unsqueeze(1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1) + gv: R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Unsqueeze2(Module): + def forward(self, input): + return input.unsqueeze(-1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + verify_model(Unsqueeze1(), example_args, {}, expected1) + verify_model(Unsqueeze2(), example_args, {}, expected2) + + def test_view(): class View(Module): def forward(self, x): @@ -2756,6 +3332,211 @@ def main( verify_model(View(), example_args, {}, expected1) +def test_arange(): + class Arange(Module): + def forward(self, input): + return torch.arange(0, 20, dtype=torch.int32) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((20,), dtype="int32")): + with R.dataflow(): + lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32") + gv: R.Tuple(R.Tensor((20,), dtype="int32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(Arange(), example_args, {}, Expected) + + +def test_clone(): + class Clone(Module): + def forward(self, input): + return torch.clone(input) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(Clone(), example_args, {}, Expected) + + +def test_empty(): + class Empty(Module): + def forward(self, input): + return torch.empty((10, 10), dtype=torch.float32) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.zeros( + R.shape([10, 10]), dtype="float32" + ) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(Empty(), example_args, {}, Expected) + + +def test_fill(): + class Fill(Module): + def forward(self, input: torch.Tensor): + return torch.fill(input, 1.5) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.full( + R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32" + ) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(Fill(), example_args, {}, Expected) + + +def test_new_ones(): + class NewOnes(Module): + def forward(self, x): + return x.new_ones(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(1, "float32"), dtype="float32" + ) + gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, dtype=torch.float32),) + verify_model(NewOnes(), example_args, {}, expected1) + + +def test_to_copy(): + # float + class ToFloat(Module): + def forward(self, x): + return x.float() + + @tvm.script.ir_module + class expected_float: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + # half + class ToHalf(Module): + def forward(self, x): + return x.half() + + @tvm.script.ir_module + class expected_half: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,) + R.output(gv) + return gv + + # type + class Type(Module): + def forward(self, x): + return x.type(torch.float32) + + @tvm.script.ir_module + class expected_type: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): + # block 0 + with R.dataflow(): + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,) + R.output(gv) + return gv + + class To1(Module): + def forward(self, input): + return input.to(torch.float16) + + @I.ir_module + class expected_to1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,) + R.output(gv) + return gv + + class To2(Module): + def forward(self, input): + return input.to("cpu") + + @I.ir_module + class expected_to2: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(ToFloat(), example_args, {}, expected_float) + verify_model(ToHalf(), example_args, {}, expected_half) + verify_model(Type(), example_args, {}, expected_type) + verify_model(To1(), example_args, {}, expected_to1) + verify_model(To2(), example_args, {}, expected_to2) + + def test_keep_params(): class Conv2D1(Module): def __init__(self): From 5298b1298a8bb9166ef99dedef9979f2719c2416 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 2 Oct 2024 22:29:48 +0900 Subject: [PATCH 185/202] [CI] Upgrade CI (#17425) * upgrade ci --- docker/Dockerfile.ci_arm | 12 +- docker/Dockerfile.ci_cortexm | 6 +- docker/Dockerfile.ci_cpu | 12 +- docker/Dockerfile.ci_gpu | 4 +- docker/Dockerfile.ci_hexagon | 4 +- docker/Dockerfile.ci_i386 | 2 +- docker/Dockerfile.ci_lint | 4 +- docker/Dockerfile.ci_minimal | 4 +- docker/Dockerfile.ci_riscv | 4 +- docker/Dockerfile.ci_wasm | 4 +- docker/Dockerfile.demo_android | 4 +- docker/Dockerfile.demo_rocm | 4 +- docker/Dockerfile.demo_vitis_ai | 4 +- docker/install/ubuntu2004_install_python.sh | 8 +- docker/install/ubuntu_install_cmake_source.sh | 32 +- docker/install/ubuntu_install_jax.sh | 18 +- .../ubuntu_install_llvm_from_source.sh | 2 +- docker/install/ubuntu_install_python.sh | 54 +- docker/install/ubuntu_install_spike_sim.sh | 68 +- docker/install/ubuntu_install_tensorflow.sh | 4 +- .../ubuntu_install_tensorflow_aarch64.sh | 4 +- docker/install/ubuntu_install_tflite.sh | 40 +- docker/install/ubuntu_install_verilator.sh | 18 +- docker/install/ubuntu_install_zephyr.sh | 6 +- docker/python/bootstrap/generate.sh | 9 +- .../bootstrap/lockfiles/constraints-3.9.txt | 588 ++++++++++++++++++ .../bootstrap/lockfiles/requirements-3.9.txt | 3 + docs/how_to/dev/setup_rpc_system.rst | 4 +- python/tvm/tir/schedule/schedule.py | 9 +- 29 files changed, 764 insertions(+), 171 deletions(-) create mode 100644 docker/python/bootstrap/lockfiles/constraints-3.9.txt create mode 100644 docker/python/bootstrap/lockfiles/requirements-3.9.txt diff --git a/docker/Dockerfile.ci_arm b/docker/Dockerfile.ci_arm index f18d95daacec..2be887079e34 100644 --- a/docker/Dockerfile.ci_arm +++ b/docker/Dockerfile.ci_arm @@ -53,10 +53,10 @@ ENV PATH /opt/sccache:$PATH COPY install/ubuntu2204_install_llvm.sh /install/ubuntu2204_install_llvm.sh RUN bash /install/ubuntu2204_install_llvm.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. @@ -71,14 +71,6 @@ RUN bash /install/ubuntu_install_tensorflow_aarch64.sh COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh RUN bash /install/ubuntu_install_tflite.sh -# Caffe deps -COPY install/ubuntu_install_boost.sh /install/ubuntu_install_boost.sh -RUN bash /install/ubuntu_install_boost.sh - -# Caffe -COPY install/ubuntu_install_caffe.sh /install/ubuntu_install_caffe.sh -RUN bash /install/ubuntu_install_caffe.sh - # ONNX COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh diff --git a/docker/Dockerfile.ci_cortexm b/docker/Dockerfile.ci_cortexm index 0a898e70581e..8006b27e84c2 100644 --- a/docker/Dockerfile.ci_cortexm +++ b/docker/Dockerfile.ci_cortexm @@ -30,15 +30,15 @@ COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh -RUN bash /install/ubuntu_install_cmake_source.sh 3.20.0 +RUN bash /install/ubuntu_install_cmake_source.sh 3.20.0 9c06b2ddf7c337e31d8201f6ebcd3bba86a9a033976a9aee207fe0c6971f4755 COPY install/ubuntu_install_googletest.sh /install/ubuntu_install_googletest.sh RUN bash /install/ubuntu_install_googletest.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 17344f7dac22..37c7c9085714 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -34,10 +34,10 @@ RUN bash /install/ubuntu_install_cmake_source.sh COPY install/ubuntu_install_googletest.sh /install/ubuntu_install_googletest.sh RUN bash /install/ubuntu_install_googletest.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. @@ -109,14 +109,6 @@ RUN bash /install/ubuntu_install_jax.sh "cpu" COPY install/ubuntu_download_arm_compute_lib_binaries.sh /install/ubuntu_download_arm_compute_lib_binaries.sh RUN bash /install/ubuntu_download_arm_compute_lib_binaries.sh -# Caffe deps -COPY install/ubuntu_install_boost.sh /install/ubuntu_install_boost.sh -RUN bash /install/ubuntu_install_boost.sh - -# Caffe -COPY install/ubuntu_install_caffe.sh /install/ubuntu_install_caffe.sh -RUN bash /install/ubuntu_install_caffe.sh - # Github Arm(R) Ethos(TM)-N NPU driver COPY install/ubuntu_install_ethosn_driver_stack.sh /install/ubuntu_install_ethosn_driver_stack.sh RUN bash /install/ubuntu_install_ethosn_driver_stack.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 8d11882098fb..1a5721c549ab 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -41,10 +41,10 @@ RUN bash /install/ubuntu_install_cmake_source.sh COPY install/ubuntu_install_googletest.sh /install/ubuntu_install_googletest.sh RUN bash /install/ubuntu_install_googletest.sh /googletest -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.ci_hexagon b/docker/Dockerfile.ci_hexagon index 1855e3a9c231..11b3041f3c56 100644 --- a/docker/Dockerfile.ci_hexagon +++ b/docker/Dockerfile.ci_hexagon @@ -37,10 +37,10 @@ RUN bash /install/ubuntu_install_cmake_source.sh COPY install/ubuntu_install_googletest.sh /install/ubuntu_install_googletest.sh RUN bash /install/ubuntu_install_googletest.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.ci_i386 b/docker/Dockerfile.ci_i386 index f1c0ee30b4d0..b96e4a33b459 100644 --- a/docker/Dockerfile.ci_i386 +++ b/docker/Dockerfile.ci_i386 @@ -49,7 +49,7 @@ ENV CARGO_HOME /opt/rust ENV PATH $PATH:$CARGO_HOME/bin ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu2004_install_python.sh /install/ubuntu2004_install_python.sh RUN bash /install/ubuntu2004_install_python.sh diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index e861b244d842..bab0cd0ebf9c 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -29,10 +29,10 @@ RUN bash /install/ubuntu_setup_tz.sh RUN apt-install-and-clear -y wget git sudo make parallel -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.ci_minimal b/docker/Dockerfile.ci_minimal index 561b68a52b3a..e7eeb12f9d13 100644 --- a/docker/Dockerfile.ci_minimal +++ b/docker/Dockerfile.ci_minimal @@ -38,10 +38,10 @@ RUN bash /install/ubuntu_install_cmake_source.sh COPY install/ubuntu_install_googletest.sh /install/ubuntu_install_googletest.sh RUN bash /install/ubuntu_install_googletest.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.ci_riscv b/docker/Dockerfile.ci_riscv index 1256562a328c..d1b5a033b6e7 100644 --- a/docker/Dockerfile.ci_riscv +++ b/docker/Dockerfile.ci_riscv @@ -35,10 +35,10 @@ RUN bash /install/ubuntu_install_cmake_source.sh COPY install/ubuntu_install_googletest.sh /install/ubuntu_install_googletest.sh RUN bash /install/ubuntu_install_googletest.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.ci_wasm b/docker/Dockerfile.ci_wasm index 000da7a31dd7..6860c51d7277 100644 --- a/docker/Dockerfile.ci_wasm +++ b/docker/Dockerfile.ci_wasm @@ -32,10 +32,10 @@ RUN bash /install/ubuntu_install_cmake_source.sh COPY install/ubuntu_install_googletest.sh /install/ubuntu_install_googletest.sh RUN bash /install/ubuntu_install_googletest.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.demo_android b/docker/Dockerfile.demo_android index b477b6d259f9..36aadbf1ee42 100644 --- a/docker/Dockerfile.demo_android +++ b/docker/Dockerfile.demo_android @@ -28,10 +28,10 @@ RUN bash /install/ubuntu_setup_tz.sh COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh RUN bash /install/ubuntu_install_core.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu1804_install_python.sh -RUN bash /install/ubuntu1804_install_python.sh 3.8 +RUN bash /install/ubuntu1804_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.demo_rocm b/docker/Dockerfile.demo_rocm index df458dd7dce4..4c6095ec4802 100644 --- a/docker/Dockerfile.demo_rocm +++ b/docker/Dockerfile.demo_rocm @@ -26,10 +26,10 @@ RUN bash /install/ubuntu_setup_tz.sh COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh RUN bash /install/ubuntu_install_core.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/Dockerfile.demo_vitis_ai b/docker/Dockerfile.demo_vitis_ai index 01b0b494bd9e..8cafc653fb6e 100644 --- a/docker/Dockerfile.demo_vitis_ai +++ b/docker/Dockerfile.demo_vitis_ai @@ -32,10 +32,10 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu_install_vitis_ai_core.sh /install/ubuntu_install_vitis_ai_core.sh RUN bash /install/ubuntu_install_vitis_ai_core.sh -ENV TVM_VENV /venv/apache-tvm-py3.8 +ENV TVM_VENV /venv/apache-tvm-py3.9 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.8 +RUN bash /install/ubuntu_install_python.sh 3.9 ENV PATH ${TVM_VENV}/bin:$PATH ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. diff --git a/docker/install/ubuntu2004_install_python.sh b/docker/install/ubuntu2004_install_python.sh index ece3c34fb0c3..33f7c90ada7c 100755 --- a/docker/install/ubuntu2004_install_python.sh +++ b/docker/install/ubuntu2004_install_python.sh @@ -30,15 +30,15 @@ trap cleanup 0 # Install python and pip. Don't modify this to add Python package dependencies, # instead modify install_python_package.sh apt-get update -apt-install-and-clear -y python3.8 python3.8-dev python3-pip -update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 +apt-install-and-clear -y python3.9 python3.9-dev python3-pip +update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 # Pin pip and setuptools versions # Hashes generated via: # $ pip download == # $ pip hash --algorithm sha256 .whl cat < base-requirements.txt -pip==23.3.2 --hash=sha256:5052d7889c1f9d05224cd41741acb7c5d6fa735ab34e339624a614eaaa7e7d76 -setuptools==58.4.0 --hash=sha256:e8b1d3127a0441fb99a130bcc3c2bf256c2d3ead3aba8fd400e5cbbaf788e036 +pip==24.2 --hash=sha256:2cd581cf58ab7fcfca4ce8efa6dcacd0de5bf8d0a3eb9ec927e07405f4d9e2a2 +setuptools==75.1.0 --hash=sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2 EOF pip3 install -r base-requirements.txt diff --git a/docker/install/ubuntu_install_cmake_source.sh b/docker/install/ubuntu_install_cmake_source.sh index 9085e19f4011..42f17f9ece89 100755 --- a/docker/install/ubuntu_install_cmake_source.sh +++ b/docker/install/ubuntu_install_cmake_source.sh @@ -20,19 +20,21 @@ set -e set -u set -o pipefail -if [ -z ${1+x} ]; then - version=3.24.0 -else - version=$1 -fi +CMAKE_VERSION="3.30.4" +CMAKE_SHA256="c759c97274f1e7aaaafcb1f0d261f9de9bf3a5d6ecb7e2df616324a46fe704b2" -v=$(echo $version | sed 's/\(.*\)\..*/\1/g') -echo "Installing cmake $version ($v)" -wget https://cmake.org/files/v${v}/cmake-${version}.tar.gz -tar xvf cmake-${version}.tar.gz -cd cmake-${version} -./bootstrap -make -j$(nproc) -make install -cd .. -rm -rf cmake-${version} cmake-${version}.tar.gz +# parse argument +CMAKE_VERSION=${1:-$CMAKE_VERSION} +CMAKE_SHA256=${2:-$CMAKE_SHA256} + +v=$(echo $CMAKE_VERSION | sed 's/\(.*\)\..*/\1/g') +echo "Installing cmake $CMAKE_VERSION ($v)" +wget https://cmake.org/files/v${v}/cmake-${CMAKE_VERSION}.tar.gz +echo "$CMAKE_SHA256" cmake-${CMAKE_VERSION}.tar.gz | sha256sum -c +tar xvf cmake-${CMAKE_VERSION}.tar.gz +pushd cmake-${CMAKE_VERSION} + ./bootstrap + make -j$(nproc) + make install +popd +rm -rf cmake-${CMAKE_VERSION} cmake-${CMAKE_VERSION}.tar.gz diff --git a/docker/install/ubuntu_install_jax.sh b/docker/install/ubuntu_install_jax.sh index 19149909161e..17114e0efce8 100644 --- a/docker/install/ubuntu_install_jax.sh +++ b/docker/install/ubuntu_install_jax.sh @@ -20,16 +20,18 @@ set -e set -u set -o pipefail -# Install jax and jaxlib +JAX_VERSION=0.4.30 + +# Install jaxlib if [ "$1" == "cuda" ]; then - pip3 install --upgrade \ - jaxlib~=0.4.9 \ - "jax[cuda11_pip]~=0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install -U \ + "jax[cuda12]~=${JAX_VERSION}" \ + jaxlib~=${JAX_VERSION} else - pip3 install --upgrade \ - jaxlib~=0.4.9 \ - "jax[cpu]~=0.4.9" + pip3 install -U \ + jax~=${JAX_VERSION} \ + jaxlib~=${JAX_VERSION} fi # Install flax -pip3 install flax~=0.6.9 +pip3 install flax~=0.8.5 diff --git a/docker/install/ubuntu_install_llvm_from_source.sh b/docker/install/ubuntu_install_llvm_from_source.sh index 6bb13c804096..f1ef7d02be6e 100644 --- a/docker/install/ubuntu_install_llvm_from_source.sh +++ b/docker/install/ubuntu_install_llvm_from_source.sh @@ -63,7 +63,7 @@ cmake \ -DLLVM_ENABLE_PROJECTS=mlir \ -DLLVM_USE_INTEL_JITEVENTS=ON \ -DLLVM_TEMPORARILY_ALLOW_OLD_TOOLCHAIN=ON \ - -DPYTHON_EXECUTABLE="$(which python3.8)" \ + -DPYTHON_EXECUTABLE="$(which python3.9)" \ -GNinja \ .. ninja install diff --git a/docker/install/ubuntu_install_python.sh b/docker/install/ubuntu_install_python.sh index 1f3ace61ef0f..664206570bc6 100755 --- a/docker/install/ubuntu_install_python.sh +++ b/docker/install/ubuntu_install_python.sh @@ -33,10 +33,13 @@ if [ "$#" -lt 1 ]; then fi PYTHON_VERSION=$1 -if [ "${PYTHON_VERSION}" != "3.7" ] && [ "${PYTHON_VERSION}" != "3.8" ]; then - echo "Only 3.7 and 3.8 versions are supported in this script." - exit -1 -fi +case "$PYTHON_VERSION" in + 3.7|3.8|3.9) ;; + *) + echo "Only 3.7, 3.8, and 3.9 versions are supported in this script." + exit -1 + ;; +esac apt-get update @@ -47,22 +50,23 @@ apt-install-and-clear -y \ apt-install-and-clear -y software-properties-common release=$(lsb_release -sc) -if [ "${release}" == "bionic" ]; then - if [ "${PYTHON_VERSION}" == "3.8" ]; then - add-apt-repository -y ppa:deadsnakes/ppa - fi -elif [ "${release}" == "focal" ]; then - if [ "${PYTHON_VERSION}" == "3.7" ]; then - add-apt-repository -y ppa:deadsnakes/ppa - fi -elif [ "${release}" == "jammy" ]; then - if [ "${PYTHON_VERSION}" == "3.8" ]; then - add-apt-repository -y ppa:deadsnakes/ppa - fi -else - echo "Don't know which version of python to install for lsb-release ${release}" - exit 2 -fi +case "${release}" in + bionic) + [ "${PYTHON_VERSION}" == "3.8" ] && add-apt-repository -y ppa:deadsnakes/ppa + ;; + focal) + [ "${PYTHON_VERSION}" == "3.7" ] && add-apt-repository -y ppa:deadsnakes/ppa + ;; + jammy) + if [ "${PYTHON_VERSION}" == "3.8" ] || [ "${PYTHON_VERSION}" == "3.9" ]; then + add-apt-repository -y ppa:deadsnakes/ppa + fi + ;; + *) + echo "Don't know which version of python to install for lsb-release ${release}" + exit 2 + ;; +esac # Install python and pip. Don't modify this to add Python package dependencies, # instead modify install_python_package.sh @@ -84,7 +88,6 @@ export PYTHONNOUSERSITE=1 venv_dir="$(python3 -c "import os.path;print(os.path.dirname(\"${TVM_VENV}\"))")" mkdir -p "${venv_dir}" python3 -mvenv "${TVM_VENV}" -. "${TVM_VENV}/bin/activate" # NOTE: Only in python3.9 does venv guarantee it creates the python3.X binary. # This is needed so that cmake's find_package(PythonInterp) works inside the venv. @@ -95,15 +98,15 @@ fi # Update pip to match version used to produce requirements-hashed.txt. This step # is necessary so that pip's dependency solver is recent. -pip_spec=$(cat /install/python/bootstrap/lockfiles/constraints-${PYTHON_VERSION}.txt | grep 'pip==') -pip3 install -U --require-hashes -r <(echo "${pip_spec}") \ +pip_spec=$(tac /install/python/bootstrap/lockfiles/constraints-${PYTHON_VERSION}.txt | grep -m 1 'pip==') +${TVM_VENV}/bin/pip install -U --require-hashes -r <(echo "${pip_spec}") \ -c /install/python/bootstrap/lockfiles/constraints-${PYTHON_VERSION}.txt # Python configuration -pip3 config set global.no-cache-dir true # Never cache packages +${TVM_VENV}/bin/pip config set global.no-cache-dir true # Never cache packages # Now install the remaining base packages. -pip3 install \ +${TVM_VENV}/bin/pip install \ --require-hashes \ -r /install/python/bootstrap/lockfiles/constraints-${PYTHON_VERSION}.txt @@ -114,7 +117,6 @@ setfacl -R -m group:tvm-venv:rwx "${TVM_VENV}" # Prevent further use of pip3 via the system. # There may be multiple (i.e. from python3-pip apt package and pip3 install -U). -deactivate while [ "$(which pip3)" != "" ]; do rm "$(which pip3)" done diff --git a/docker/install/ubuntu_install_spike_sim.sh b/docker/install/ubuntu_install_spike_sim.sh index 24a11d758c38..7bc2a992030c 100755 --- a/docker/install/ubuntu_install_spike_sim.sh +++ b/docker/install/ubuntu_install_spike_sim.sh @@ -39,43 +39,49 @@ export RISCV=$1 export PATH=$RISCV/bin:$PATH shift -sudo apt-install-and-clear -y --no-install-recommends device-tree-compiler +# Install dependency +apt-install-and-clear -y --no-install-recommends device-tree-compiler # Install spike mkdir /tmp/spike -cd /tmp/spike -# TODO: freeze version? -git clone https://github.com/riscv/riscv-isa-sim.git -pushd riscv-isa-sim -mkdir build -cd build -../configure --prefix=$RISCV --with-isa=RV32IMAC -make -j`nproc` -make install -popd - -# Install pk -git clone https://github.com/riscv/riscv-pk.git -pushd riscv-pk +pushd /tmp/spike + # TODO: freeze version? + git clone https://github.com/riscv/riscv-isa-sim.git + pushd riscv-isa-sim + mkdir build + cd build + ../configure --prefix=$RISCV --with-isa=RV32IMAC + make -j`nproc` + make install + popd -# rv32imac -mkdir build -pushd build -../configure --prefix=`pwd`/install --host=riscv64-unknown-elf --with-arch=rv32imac -make -j`nproc` -make install -cp ./pk $RISCV/riscv64-unknown-elf/bin/pk -popd + # Install pk + git clone https://github.com/riscv/riscv-pk.git + pushd riscv-pk + # With commit 47a2e87, we get the below compilation so we'll use the specific commit + # ../pk/pk.c: Assembler messages: + # ../pk/pk.c:122: Error: unknown CSR `ssp' + git checkout 1a52fa4 -git status + # rv32imac + mkdir build + pushd build + ../configure --prefix=`pwd`/install --host=riscv64-unknown-elf --with-arch=rv32imac + make -j`nproc` + make install + cp ./pk $RISCV/riscv64-unknown-elf/bin/pk + popd -# rv64imac -mkdir build64 -pushd build64 -../configure --prefix=`pwd`/install --host=riscv64-unknown-elf --with-arch=rv64imac -make -j`nproc` -make install -cp ./pk $RISCV/riscv64-unknown-elf/bin/pk64 + # rv64imac + mkdir build64 + pushd build64 + ../configure --prefix=`pwd`/install --host=riscv64-unknown-elf --with-arch=rv64imac + make -j`nproc` + make install + cp ./pk $RISCV/riscv64-unknown-elf/bin/pk64 + popd + popd +popd # cleanup rm -rf /tmp/spike diff --git a/docker/install/ubuntu_install_tensorflow.sh b/docker/install/ubuntu_install_tensorflow.sh index 2225b7aef3b8..012b678916b3 100755 --- a/docker/install/ubuntu_install_tensorflow.sh +++ b/docker/install/ubuntu_install_tensorflow.sh @@ -21,5 +21,5 @@ set -u set -o pipefail pip3 install \ - keras==2.9 \ - tensorflow==2.9.1 + keras==3.5 \ + tensorflow==2.17.0 diff --git a/docker/install/ubuntu_install_tensorflow_aarch64.sh b/docker/install/ubuntu_install_tensorflow_aarch64.sh index fcd912a4478a..4b158948387b 100755 --- a/docker/install/ubuntu_install_tensorflow_aarch64.sh +++ b/docker/install/ubuntu_install_tensorflow_aarch64.sh @@ -25,5 +25,5 @@ apt-install-and-clear -y --no-install-recommends libhdf5-dev # h5py wheel tries to use the wrong .so file pip3 install \ numpy==1.23.5 \ - keras==2.9 \ - tensorflow-aarch64~=2.9.3 + keras==3.5 \ + tensorflow-aarch64~=2.16.1 diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index 36e6dfc42794..8faabc022640 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -26,11 +26,11 @@ set -o pipefail TENSORFLOW_VERSION=$(python3 -c "import tensorflow; print(tensorflow.__version__)" 2> /dev/null) # Download, build and install flatbuffers -git clone --branch=v1.12.0 --depth=1 --recursive https://github.com/google/flatbuffers.git -cd flatbuffers -cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-class-memaccess" -make install -j8 -cd .. +git clone --branch=v24.3.25 --depth=1 --recursive https://github.com/google/flatbuffers.git +pushd flatbuffers + cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-class-memaccess" + ninja install -j8 +popd # Install flatbuffers python packages. pip3 install flatbuffers @@ -41,22 +41,22 @@ pip3 install flatbuffers git clone https://github.com/tensorflow/tensorflow --branch=v${TENSORFLOW_VERSION} --depth 1 mkdir -p /opt/tflite -cd /opt/tflite -cmake \ - -DTFLITE_ENABLE_XNNPACK=OFF \ - /tensorflow/tensorflow/lite - -cmake --build . -cd - +pushd /opt/tflite + cmake -G Ninja \ + -DTFLITE_ENABLE_XNNPACK=OFF \ + /tensorflow/tensorflow/lite + cmake --build . +popd # Setup tflite from schema mkdir tflite -cp tensorflow/tensorflow/lite/schema/schema.fbs tflite -cd tflite -flatc --python schema.fbs +find / -name "schema.fbs" +cp /tensorflow/tensorflow/lite/stablehlo/schema/schema.fbs tflite +pushd tflite + flatc --python schema.fbs -cat <setup.py + cat <setup.py import setuptools setuptools.setup( @@ -77,12 +77,12 @@ setuptools.setup( ) EOM -cat <__init__.py + cat <__init__.py name = "tflite" EOM -# Install tflite over python3 -python3 setup.py install + # Install tflite over python3 + python3 setup.py install -cd .. +popd rm -rf tflite diff --git a/docker/install/ubuntu_install_verilator.sh b/docker/install/ubuntu_install_verilator.sh index 4aef7bc2da96..630746bd2162 100755 --- a/docker/install/ubuntu_install_verilator.sh +++ b/docker/install/ubuntu_install_verilator.sh @@ -21,17 +21,17 @@ set -u set -o pipefail # Verilator version -version="5.002" +VERILATOR_VERSION="5.002" # Install dependencies apt-get update && apt-install-and-clear -y autoconf g++ flex bison # Install Verilator -wget "https://github.com/verilator/verilator/archive/v$version.tar.gz" -tar xf "v$version.tar.gz" -rm "v$version.tar.gz" -cd "verilator-$version" -autoconf -./configure -make -j4 -make install +git clone --depth 1 --branch v${VERILATOR_VERSION} https://github.com/verilator/verilator +pushd verilator + autoconf + ./configure + make -j$(nproc) + make install +popd +rm -rf verilator diff --git a/docker/install/ubuntu_install_zephyr.sh b/docker/install/ubuntu_install_zephyr.sh index 3cef1e9c40c9..55bdacb0c0ce 100755 --- a/docker/install/ubuntu_install_zephyr.sh +++ b/docker/install/ubuntu_install_zephyr.sh @@ -47,9 +47,9 @@ release=$(lsb_release -sc) if [ "${release}" == "bionic" ]; then python_cmd="python3" elif [ "${release}" == "focal" ]; then - python_cmd="python3.8" + python_cmd="python3.9" elif [ "${release}" == "jammy" ]; then - python_cmd="python3.8" + python_cmd="python3.9" else echo "Don't know which version of python to use for Zephyr." exit 2 @@ -64,7 +64,7 @@ $python_cmd -m pip install west # Init ZephyrProject ZEPHYR_PROJECT_PATH=/opt/zephyrproject -bash /install/ubuntu_init_zephyr_project.sh ${ZEPHYR_PROJECT_PATH} +bash /install/ubuntu_init_zephyr_project.sh ${ZEPHYR_PROJECT_PATH} --branch v3.6-branch cd ${ZEPHYR_PROJECT_PATH} # As part of the build process, Zephyr needs to touch some symlinks in zephyr/misc/generated/syscalls_links (this path is relative to the diff --git a/docker/python/bootstrap/generate.sh b/docker/python/bootstrap/generate.sh index 116b8d8daee0..830c03b7b1c1 100755 --- a/docker/python/bootstrap/generate.sh +++ b/docker/python/bootstrap/generate.sh @@ -41,7 +41,7 @@ description = "" [tool.poetry.dependencies] python = "^$1" pip = "*" -poetry = "1.2.0b1" +poetry = "1.8.1" setuptools = "*" EOF @@ -50,7 +50,7 @@ EOF pwd . build/$1/_venv/bin/activate (mkdir -p build/$1/downloaded && cd build/$1/downloaded && pip3 download pip setuptools && pip3 install *.whl) - pip3 install poetry + pip3 install poetry poetry-plugin-export (cd build/$1 && poetry lock) # Now export requirements.txt and constraints.txt for @@ -73,7 +73,7 @@ with open("requirements.txt", "w") as f: EOF # For - (cd build/$1 && poetry export -o constraints.txt) + (cd build/$1 && poetry export -f constraints.txt -o constraints.txt) (cd build/$1 && python3 <= "3.9" and python_version < "4.0" \ + --hash=sha256:119b2fb462adef986483438377a13b2f42064a2a3a4161f24a0cca698a07ac8c \ + --hash=sha256:277ccc71619d98afdd841a0e96ac9fe1593b823af481d3b0cea748e8894e0613 +cachecontrol==0.14.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:7db1195b41c81f8274a7bbd97c956f44e8348265a1bc7641c37dfebc39f0c938 \ + --hash=sha256:f5bf3f0620c38db2e5122c0726bdebb0d16869de966ea6a2befe92470b740ea0 +certifi==2024.8.30 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8 \ + --hash=sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9 +cffi==1.17.1 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "darwin" or sys_platform == "linux") and (sys_platform == "darwin" or platform_python_implementation != "PyPy") \ + --hash=sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8 \ + --hash=sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2 \ + --hash=sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1 \ + --hash=sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15 \ + --hash=sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36 \ + --hash=sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824 \ + --hash=sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8 \ + --hash=sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36 \ + --hash=sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17 \ + --hash=sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf \ + --hash=sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc \ + --hash=sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3 \ + --hash=sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed \ + --hash=sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702 \ + --hash=sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1 \ + --hash=sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8 \ + --hash=sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903 \ + --hash=sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6 \ + --hash=sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d \ + --hash=sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b \ + --hash=sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e \ + --hash=sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be \ + --hash=sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c \ + --hash=sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683 \ + --hash=sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9 \ + --hash=sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c \ + --hash=sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8 \ + --hash=sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1 \ + --hash=sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4 \ + --hash=sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655 \ + --hash=sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67 \ + --hash=sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595 \ + --hash=sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0 \ + --hash=sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65 \ + --hash=sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41 \ + --hash=sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6 \ + --hash=sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401 \ + --hash=sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6 \ + --hash=sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3 \ + --hash=sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16 \ + --hash=sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93 \ + --hash=sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e \ + --hash=sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4 \ + --hash=sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964 \ + --hash=sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c \ + --hash=sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576 \ + --hash=sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0 \ + --hash=sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3 \ + --hash=sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662 \ + --hash=sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3 \ + --hash=sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff \ + --hash=sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5 \ + --hash=sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd \ + --hash=sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f \ + --hash=sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5 \ + --hash=sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14 \ + --hash=sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d \ + --hash=sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9 \ + --hash=sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7 \ + --hash=sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382 \ + --hash=sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a \ + --hash=sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e \ + --hash=sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a \ + --hash=sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4 \ + --hash=sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99 \ + --hash=sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87 \ + --hash=sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b +charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 +cleo==2.1.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:0b2c880b5d13660a7ea651001fb4acb527696c01f15c9ee650f377aa543fd523 \ + --hash=sha256:4a31bd4dd45695a64ee3c4758f583f134267c2bc518d8ae9a29cf237d009b07e +colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and os_name == "nt" \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 +crashtest==0.4.1 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:80d7b1f316ebfbd429f648076d6275c877ba30ba48979de4191714a75266f0ce \ + --hash=sha256:8d23eac5fa660409f57472e3851dab7ac18aba459a8d19cbbba86d3d5aecd2a5 +cryptography==43.0.1 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "linux" \ + --hash=sha256:014f58110f53237ace6a408b5beb6c427b64e084eb451ef25a28308270086494 \ + --hash=sha256:1bbcce1a551e262dfbafb6e6252f1ae36a248e615ca44ba302df077a846a8806 \ + --hash=sha256:203e92a75716d8cfb491dc47c79e17d0d9207ccffcbcb35f598fbe463ae3444d \ + --hash=sha256:27e613d7077ac613e399270253259d9d53872aaf657471473ebfc9a52935c062 \ + --hash=sha256:2bd51274dcd59f09dd952afb696bf9c61a7a49dfc764c04dd33ef7a6b502a1e2 \ + --hash=sha256:38926c50cff6f533f8a2dae3d7f19541432610d114a70808f0926d5aaa7121e4 \ + --hash=sha256:511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1 \ + --hash=sha256:58d4e9129985185a06d849aa6df265bdd5a74ca6e1b736a77959b498e0505b85 \ + --hash=sha256:5b43d1ea6b378b54a1dc99dd8a2b5be47658fe9a7ce0a58ff0b55f4b43ef2b84 \ + --hash=sha256:61ec41068b7b74268fa86e3e9e12b9f0c21fcf65434571dbb13d954bceb08042 \ + --hash=sha256:666ae11966643886c2987b3b721899d250855718d6d9ce41b521252a17985f4d \ + --hash=sha256:68aaecc4178e90719e95298515979814bda0cbada1256a4485414860bd7ab962 \ + --hash=sha256:7c05650fe8023c5ed0d46793d4b7d7e6cd9c04e68eabe5b0aeea836e37bdcec2 \ + --hash=sha256:80eda8b3e173f0f247f711eef62be51b599b5d425c429b5d4ca6a05e9e856baa \ + --hash=sha256:8385d98f6a3bf8bb2d65a73e17ed87a3ba84f6991c155691c51112075f9ffc5d \ + --hash=sha256:88cce104c36870d70c49c7c8fd22885875d950d9ee6ab54df2745f83ba0dc365 \ + --hash=sha256:9d3cdb25fa98afdd3d0892d132b8d7139e2c087da1712041f6b762e4f807cc96 \ + --hash=sha256:a575913fb06e05e6b4b814d7f7468c2c660e8bb16d8d5a1faf9b33ccc569dd47 \ + --hash=sha256:ac119bb76b9faa00f48128b7f5679e1d8d437365c5d26f1c2c3f0da4ce1b553d \ + --hash=sha256:c1332724be35d23a854994ff0b66530119500b6053d0bd3363265f7e5e77288d \ + --hash=sha256:d03a475165f3134f773d1388aeb19c2d25ba88b6a9733c5c590b9ff7bbfa2e0c \ + --hash=sha256:d75601ad10b059ec832e78823b348bfa1a59f6b8d545db3a24fd44362a1564cb \ + --hash=sha256:de41fd81a41e53267cb020bb3a7212861da53a7d39f863585d13ea11049cf277 \ + --hash=sha256:e710bf40870f4db63c3d7d929aa9e09e4e7ee219e703f949ec4073b4294f6172 \ + --hash=sha256:ea25acb556320250756e53f9e20a4177515f012c9eaea17eb7587a8c4d8ae034 \ + --hash=sha256:f98bf604c82c416bc829e490c700ca1553eafdf2912a91e23a79d97d9801372a \ + --hash=sha256:fba1007b3ef89946dbbb515aeeb41e30203b004f0b4b00e5e16078b518563289 +distlib==0.3.8 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784 \ + --hash=sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64 +dulwich==0.21.7 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:0fc3078a1ba04c588fabb0969d3530efd5cd1ce2cf248eefb6baf7cbc15fc285 \ + --hash=sha256:10893105c6566fc95bc2a67b61df7cc1e8f9126d02a1df6a8b2b82eb59db8ab9 \ + --hash=sha256:12d61334a575474e707614f2e93d6ed4cdae9eb47214f9277076d9e5615171d3 \ + --hash=sha256:2590e9b431efa94fc356ae33b38f5e64f1834ec3a94a6ac3a64283b206d07aa3 \ + --hash=sha256:25c3ab8fb2e201ad2031ddd32e4c68b7c03cb34b24a5ff477b7a7dcef86372f5 \ + --hash=sha256:274c18ec3599a92a9b67abaf110e4f181a4f779ee1aaab9e23a72e89d71b2bd9 \ + --hash=sha256:29bb5c1d70eba155ded41ed8a62be2f72edbb3c77b08f65b89c03976292f6d1b \ + --hash=sha256:2bc12697f0918bee324c18836053644035362bb3983dc1b210318f2fed1d7132 \ + --hash=sha256:2e2c66888207b71cd1daa2acb06d3984a6bc13787b837397a64117aa9fc5936a \ + --hash=sha256:404b8edeb3c3a86c47c0a498699fc064c93fa1f8bab2ffe919e8ab03eafaaad3 \ + --hash=sha256:40dcbd29ba30ba2c5bfbab07a61a5f20095541d5ac66d813056c122244df4ac0 \ + --hash=sha256:460b3849d5c3d3818a80743b4f7a0094c893c559f678e56a02fff570b49a644a \ + --hash=sha256:460ba74bdb19f8d498786ae7776745875059b1178066208c0fd509792d7f7bfc \ + --hash=sha256:4637cbd8ed1012f67e1068aaed19fcc8b649bcf3e9e26649826a303298c89b9d \ + --hash=sha256:471305af74790827fcbafe330fc2e8bdcee4fb56ca1177c8c481b1c8f806c4a4 \ + --hash=sha256:4a043b90958cec866b4edc6aef5fe3c2c96a664d0b357e1682a46f6c477273c4 \ + --hash=sha256:4b09bc3a64fb70132ec14326ecbe6e0555381108caff3496898962c4136a48c6 \ + --hash=sha256:4bc4c5366eaf26dda3fdffe160a3b515666ed27c2419f1d483da285ac1411de0 \ + --hash=sha256:4c51058ec4c0b45dc5189225b9e0c671b96ca9713c1daf71d622c13b0ab07681 \ + --hash=sha256:4f18f0a311fb7734b033a3101292b932158cade54b74d1c44db519e42825e5a2 \ + --hash=sha256:61e3451bd3d3844f2dca53f131982553be4d1b1e1ebd9db701843dd76c4dba31 \ + --hash=sha256:62bfb26bdce869cd40be443dfd93143caea7089b165d2dcc33de40f6ac9d812a \ + --hash=sha256:675a612ce913081beb0f37b286891e795d905691dfccfb9bf73721dca6757cde \ + --hash=sha256:6bd69921fdd813b7469a3c77bc75c1783cc1d8d72ab15a406598e5a3ba1a1503 \ + --hash=sha256:6c589468e5c0cd84e97eb7ec209ab005a2cb69399e8c5861c3edfe38989ac3a8 \ + --hash=sha256:6de6f8de4a453fdbae8062a6faa652255d22a3d8bce0cd6d2d6701305c75f2b3 \ + --hash=sha256:739b191f61e1c4ce18ac7d520e7a7cbda00e182c3489552408237200ce8411ad \ + --hash=sha256:74700e4c7d532877355743336c36f51b414d01e92ba7d304c4f8d9a5946dbc81 \ + --hash=sha256:7836da3f4110ce684dcd53489015fb7fa94ed33c5276e3318b8b1cbcb5b71e08 \ + --hash=sha256:7bca4b86e96d6ef18c5bc39828ea349efb5be2f9b1f6ac9863f90589bac1084d \ + --hash=sha256:7d8ab29c660125db52106775caa1f8f7f77a69ed1fe8bc4b42bdf115731a25bf \ + --hash=sha256:808e8b9cc0aa9ac74870b49db4f9f39a52fb61694573f84b9c0613c928d4caf8 \ + --hash=sha256:817822f970e196e757ae01281ecbf21369383285b9f4a83496312204cf889b8c \ + --hash=sha256:8278835e168dd097089f9e53088c7a69c6ca0841aef580d9603eafe9aea8c358 \ + --hash=sha256:858842b30ad6486aacaa607d60bab9c9a29e7c59dc2d9cb77ae5a94053878c08 \ + --hash=sha256:869eb7be48243e695673b07905d18b73d1054a85e1f6e298fe63ba2843bb2ca1 \ + --hash=sha256:8869fc8ec3dda743e03d06d698ad489b3705775fe62825e00fa95aa158097fc0 \ + --hash=sha256:8929c37986c83deb4eb500c766ee28b6670285b512402647ee02a857320e377c \ + --hash=sha256:a0650ec77d89cb947e3e4bbd4841c96f74e52b4650830112c3057a8ca891dc2f \ + --hash=sha256:a7b5624b02ef808cdc62dabd47eb10cd4ac15e8ac6df9e2e88b6ac6b40133673 \ + --hash=sha256:a9e9c66833cea580c3ac12927e4b9711985d76afca98da971405d414de60e968 \ + --hash=sha256:b0d2e4485b98695bf95350ce9d38b1bb0aaac2c34ad00a0df789aa33c934469b \ + --hash=sha256:c01a735b9a171dcb634a97a3cec1b174cfbfa8e840156870384b633da0460f18 \ + --hash=sha256:c3a539b4696a42fbdb7412cb7b66a4d4d332761299d3613d90a642923c7560e1 \ + --hash=sha256:c3d1685f320907a52c40fd5890627945c51f3a5fa4bcfe10edb24fec79caadec \ + --hash=sha256:c92e72c43c9e9e936b01a57167e0ea77d3fd2d82416edf9489faa87278a1cdf7 \ + --hash=sha256:cc1e11be527ac06316539b57a7688bcb1b6a3e53933bc2f844397bc50734e9ae \ + --hash=sha256:ce8db196e79c1f381469410d26fb1d8b89c6b87a4e7f00ff418c22a35121405c \ + --hash=sha256:d05d3c781bc74e2c2a2a8f4e4e2ed693540fbe88e6ac36df81deac574a6dad99 \ + --hash=sha256:d097e963eb6b9fa53266146471531ad9c6765bf390849230311514546ed64db2 \ + --hash=sha256:d4a2d76c96426e791556836ef43542b639def81be4f1d6d4322cd886c115eae1 \ + --hash=sha256:d4c0110798099bb7d36a110090f2688050703065448895c4f53ade808d889dd3 \ + --hash=sha256:d54c9d0e845be26f65f954dff13a1cd3f2b9739820c19064257b8fd7435ab263 \ + --hash=sha256:d5882e70b74ac3c736a42d3fdd4f5f2e6570637f59ad5d3e684760290b58f041 \ + --hash=sha256:d62446797163317a397a10080c6397ffaaca51a7804c0120b334f8165736c56a \ + --hash=sha256:d96ca5e0dde49376fbcb44f10eddb6c30284a87bd03bb577c59bb0a1f63903fa \ + --hash=sha256:e0064363bd5e814359657ae32517fa8001e8573d9d040bd997908d488ab886ed \ + --hash=sha256:e138d516baa6b5bafbe8f030eccc544d0d486d6819b82387fc0e285e62ef5261 \ + --hash=sha256:e1957b65f96e36c301e419d7adaadcff47647c30eb072468901bb683b1000bc5 \ + --hash=sha256:e25953c7acbbe4e19650d0225af1c0c0e6882f8bddd2056f75c1cc2b109b88ad \ + --hash=sha256:e274cebaf345f0b1e3b70197f2651de92b652386b68020cfd3bf61bc30f6eaaa \ + --hash=sha256:e598d743c6c0548ebcd2baf94aa9c8bfacb787ea671eeeb5828cfbd7d56b552f \ + --hash=sha256:e84cc606b1f581733df4350ca4070e6a8b30be3662bbb81a590b177d0c996c91 \ + --hash=sha256:ecd315847dea406a4decfa39d388a2521e4e31acde3bd9c2609c989e817c6d62 \ + --hash=sha256:ed60d1f610ef6437586f7768254c2a93820ccbd4cfdac7d182cf2d6e615969bb \ + --hash=sha256:f34bf9b9fa9308376263fd9ac43143c7c09da9bc75037bb75c6c2423a151b92c \ + --hash=sha256:f6c88acb60a1f4d31bd6d13bfba465853b3df940ee4a0f2a3d6c7a0778c705b7 \ + --hash=sha256:fa4d14767cf7a49c9231c2e52cb2a3e90d0c83f843eb6a2ca2b5d81d254cf6b9 \ + --hash=sha256:ffc27fb063f740712e02b4d2f826aee8bbed737ed799962fef625e2ce56e2d29 +fastjsonschema==2.20.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:3d48fc5300ee96f5d116f10fe6f28d938e6008f59a6a025c2649475b87f76a23 \ + --hash=sha256:5875f0b0fa7a0043a91e93a9b8f793bcbbba9691e7fd83dca95c28ba26d21f0a +filelock==3.16.1 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0 \ + --hash=sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435 +idna==3.10 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9 \ + --hash=sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3 +importlib-metadata==8.5.0 ; python_version >= "3.9" and python_version < "3.12" \ + --hash=sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b \ + --hash=sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7 +installer==0.7.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:05d1933f0a5ba7d8d6296bb6d5018e7c94fa473ceb10cf198a92ccea19c27b53 \ + --hash=sha256:a26d3e3116289bb08216e0d0f7d925fcef0b0194eedfa0c944bcaaa106c4b631 +jaraco-classes==3.4.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:47a024b51d0239c0dd8c8540c6c7f484be3b8fcf0b2d85c13825780d3b3f3acd \ + --hash=sha256:f662826b6bed8cace05e7ff873ce0f9283b5c924470fe664fff1c2f00f581790 +jeepney==0.8.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "linux" \ + --hash=sha256:5efe48d255973902f6badc3ce55e2aa6c5c3b3bc642059ef3a91247bcfcc5806 \ + --hash=sha256:c0a454ad016ca575060802ee4d590dd912e35c122fa04e70306de3d076cce755 +keyring==24.3.1 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:c3327b6ffafc0e8befbdb597cacdb4928ffe5c1212f7645f186e6d9957a898db \ + --hash=sha256:df38a4d7419a6a60fea5cef1e45a948a3e8430dd12ad88b0f423c5c143906218 +more-itertools==10.5.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:037b0d3203ce90cca8ab1defbbdac29d5f993fc20131f3664dc8d6acfa872aef \ + --hash=sha256:5482bfef7849c25dc3c6dd53a6173ae4795da2a41a80faea6700d9f5846c5da6 +msgpack==1.1.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:06f5fd2f6bb2a7914922d935d3b8bb4a7fff3a9a91cfce6d06c13bc42bec975b \ + --hash=sha256:071603e2f0771c45ad9bc65719291c568d4edf120b44eb36324dcb02a13bfddf \ + --hash=sha256:0907e1a7119b337971a689153665764adc34e89175f9a34793307d9def08e6ca \ + --hash=sha256:0f92a83b84e7c0749e3f12821949d79485971f087604178026085f60ce109330 \ + --hash=sha256:115a7af8ee9e8cddc10f87636767857e7e3717b7a2e97379dc2054712693e90f \ + --hash=sha256:13599f8829cfbe0158f6456374e9eea9f44eee08076291771d8ae93eda56607f \ + --hash=sha256:17fb65dd0bec285907f68b15734a993ad3fc94332b5bb21b0435846228de1f39 \ + --hash=sha256:2137773500afa5494a61b1208619e3871f75f27b03bcfca7b3a7023284140247 \ + --hash=sha256:3180065ec2abbe13a4ad37688b61b99d7f9e012a535b930e0e683ad6bc30155b \ + --hash=sha256:398b713459fea610861c8a7b62a6fec1882759f308ae0795b5413ff6a160cf3c \ + --hash=sha256:3d364a55082fb2a7416f6c63ae383fbd903adb5a6cf78c5b96cc6316dc1cedc7 \ + --hash=sha256:3df7e6b05571b3814361e8464f9304c42d2196808e0119f55d0d3e62cd5ea044 \ + --hash=sha256:41c991beebf175faf352fb940bf2af9ad1fb77fd25f38d9142053914947cdbf6 \ + --hash=sha256:42f754515e0f683f9c79210a5d1cad631ec3d06cea5172214d2176a42e67e19b \ + --hash=sha256:452aff037287acb1d70a804ffd022b21fa2bb7c46bee884dbc864cc9024128a0 \ + --hash=sha256:4676e5be1b472909b2ee6356ff425ebedf5142427842aa06b4dfd5117d1ca8a2 \ + --hash=sha256:46c34e99110762a76e3911fc923222472c9d681f1094096ac4102c18319e6468 \ + --hash=sha256:471e27a5787a2e3f974ba023f9e265a8c7cfd373632247deb225617e3100a3c7 \ + --hash=sha256:4a1964df7b81285d00a84da4e70cb1383f2e665e0f1f2a7027e683956d04b734 \ + --hash=sha256:4b51405e36e075193bc051315dbf29168d6141ae2500ba8cd80a522964e31434 \ + --hash=sha256:4d1b7ff2d6146e16e8bd665ac726a89c74163ef8cd39fa8c1087d4e52d3a2325 \ + --hash=sha256:53258eeb7a80fc46f62fd59c876957a2d0e15e6449a9e71842b6d24419d88ca1 \ + --hash=sha256:534480ee5690ab3cbed89d4c8971a5c631b69a8c0883ecfea96c19118510c846 \ + --hash=sha256:58638690ebd0a06427c5fe1a227bb6b8b9fdc2bd07701bec13c2335c82131a88 \ + --hash=sha256:58dfc47f8b102da61e8949708b3eafc3504509a5728f8b4ddef84bd9e16ad420 \ + --hash=sha256:59caf6a4ed0d164055ccff8fe31eddc0ebc07cf7326a2aaa0dbf7a4001cd823e \ + --hash=sha256:5dbad74103df937e1325cc4bfeaf57713be0b4f15e1c2da43ccdd836393e2ea2 \ + --hash=sha256:5e1da8f11a3dd397f0a32c76165cf0c4eb95b31013a94f6ecc0b280c05c91b59 \ + --hash=sha256:646afc8102935a388ffc3914b336d22d1c2d6209c773f3eb5dd4d6d3b6f8c1cb \ + --hash=sha256:64fc9068d701233effd61b19efb1485587560b66fe57b3e50d29c5d78e7fef68 \ + --hash=sha256:65553c9b6da8166e819a6aa90ad15288599b340f91d18f60b2061f402b9a4915 \ + --hash=sha256:685ec345eefc757a7c8af44a3032734a739f8c45d1b0ac45efc5d8977aa4720f \ + --hash=sha256:6ad622bf7756d5a497d5b6836e7fc3752e2dd6f4c648e24b1803f6048596f701 \ + --hash=sha256:73322a6cc57fcee3c0c57c4463d828e9428275fb85a27aa2aa1a92fdc42afd7b \ + --hash=sha256:74bed8f63f8f14d75eec75cf3d04ad581da6b914001b474a5d3cd3372c8cc27d \ + --hash=sha256:79ec007767b9b56860e0372085f8504db5d06bd6a327a335449508bbee9648fa \ + --hash=sha256:7a946a8992941fea80ed4beae6bff74ffd7ee129a90b4dd5cf9c476a30e9708d \ + --hash=sha256:7ad442d527a7e358a469faf43fda45aaf4ac3249c8310a82f0ccff9164e5dccd \ + --hash=sha256:7c9a35ce2c2573bada929e0b7b3576de647b0defbd25f5139dcdaba0ae35a4cc \ + --hash=sha256:7e7b853bbc44fb03fbdba34feb4bd414322180135e2cb5164f20ce1c9795ee48 \ + --hash=sha256:879a7b7b0ad82481c52d3c7eb99bf6f0645dbdec5134a4bddbd16f3506947feb \ + --hash=sha256:8a706d1e74dd3dea05cb54580d9bd8b2880e9264856ce5068027eed09680aa74 \ + --hash=sha256:8a84efb768fb968381e525eeeb3d92857e4985aacc39f3c47ffd00eb4509315b \ + --hash=sha256:8cf9e8c3a2153934a23ac160cc4cba0ec035f6867c8013cc6077a79823370346 \ + --hash=sha256:8da4bf6d54ceed70e8861f833f83ce0814a2b72102e890cbdfe4b34764cdd66e \ + --hash=sha256:8e59bca908d9ca0de3dc8684f21ebf9a690fe47b6be93236eb40b99af28b6ea6 \ + --hash=sha256:914571a2a5b4e7606997e169f64ce53a8b1e06f2cf2c3a7273aa106236d43dd5 \ + --hash=sha256:a51abd48c6d8ac89e0cfd4fe177c61481aca2d5e7ba42044fd218cfd8ea9899f \ + --hash=sha256:a52a1f3a5af7ba1c9ace055b659189f6c669cf3657095b50f9602af3a3ba0fe5 \ + --hash=sha256:ad33e8400e4ec17ba782f7b9cf868977d867ed784a1f5f2ab46e7ba53b6e1e1b \ + --hash=sha256:b4c01941fd2ff87c2a934ee6055bda4ed353a7846b8d4f341c428109e9fcde8c \ + --hash=sha256:bce7d9e614a04d0883af0b3d4d501171fbfca038f12c77fa838d9f198147a23f \ + --hash=sha256:c40ffa9a15d74e05ba1fe2681ea33b9caffd886675412612d93ab17b58ea2fec \ + --hash=sha256:c5a91481a3cc573ac8c0d9aace09345d989dc4a0202b7fcb312c88c26d4e71a8 \ + --hash=sha256:c921af52214dcbb75e6bdf6a661b23c3e6417f00c603dd2070bccb5c3ef499f5 \ + --hash=sha256:d46cf9e3705ea9485687aa4001a76e44748b609d260af21c4ceea7f2212a501d \ + --hash=sha256:d8ce0b22b890be5d252de90d0e0d119f363012027cf256185fc3d474c44b1b9e \ + --hash=sha256:dd432ccc2c72b914e4cb77afce64aab761c1137cc698be3984eee260bcb2896e \ + --hash=sha256:e0856a2b7e8dcb874be44fea031d22e5b3a19121be92a1e098f46068a11b0870 \ + --hash=sha256:e1f3c3d21f7cf67bcf2da8e494d30a75e4cf60041d98b3f79875afb5b96f3a3f \ + --hash=sha256:f1ba6136e650898082d9d5a5217d5906d1e138024f836ff48691784bbe1adf96 \ + --hash=sha256:f3e9b4936df53b970513eac1758f3882c88658a220b58dcc1e39606dccaaf01c \ + --hash=sha256:f80bc7d47f76089633763f952e67f8214cb7b3ee6bfa489b3cb6a84cfac114cd \ + --hash=sha256:fd2906780f25c8ed5d7b323379f6138524ba793428db5d0e9d226d3fa6aa1788 +packaging==24.1 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ + --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 +pexpect==4.9.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523 \ + --hash=sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f +pip==24.2 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:2cd581cf58ab7fcfca4ce8efa6dcacd0de5bf8d0a3eb9ec927e07405f4d9e2a2 \ + --hash=sha256:5b5e490b5e9cb275c879595064adce9ebd31b854e3e803740b72f9ccf34a45b8 +pkginfo==1.11.1 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:2e0dca1cf4c8e39644eed32408ea9966ee15e0d324c62ba899a393b3c6b467aa \ + --hash=sha256:bfa76a714fdfc18a045fcd684dbfc3816b603d9d075febef17cb6582bea29573 +platformdirs==4.3.6 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907 \ + --hash=sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb +poetry-core==1.9.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:4e0c9c6ad8cf89956f03b308736d84ea6ddb44089d16f2adc94050108ec1f5a1 \ + --hash=sha256:fa7a4001eae8aa572ee84f35feb510b321bd652e5cf9293249d62853e1f935a2 +poetry-plugin-export==1.8.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:1fa6168a85d59395d835ca564bc19862a7c76061e60c3e7dfaec70d50937fc61 \ + --hash=sha256:adbe232cfa0cc04991ea3680c865cf748bff27593b9abcb1f35fb50ed7ba2c22 +poetry==1.8.3 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:67f4eb68288eab41e841cc71a00d26cf6bdda9533022d0189a145a34d0a35f48 \ + --hash=sha256:88191c69b08d06f9db671b793d68f40048e8904c0718404b63dcc2b5aec62d13 +ptyprocess==0.7.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35 \ + --hash=sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220 +pycparser==2.22 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "darwin" or sys_platform == "linux") and (sys_platform == "darwin" or platform_python_implementation != "PyPy") \ + --hash=sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6 \ + --hash=sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc +pyproject-hooks==1.1.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ + --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 +pywin32-ctypes==0.2.3 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" \ + --hash=sha256:8a1513379d709975552d202d942d9837758905c8d01eb82b8bcc30918929e7b8 \ + --hash=sha256:d162dc04946d704503b2edc4d55f3dba5c1d539ead017afa00142c38b9885755 +rapidfuzz==3.10.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:094c26116d55bf9c53abd840d08422f20da78ec4c4723e5024322321caedca48 \ + --hash=sha256:0ec338d5f4ad8d9339a88a08db5c23e7f7a52c2b2a10510c48a0cef1fb3f0ddc \ + --hash=sha256:10fdad800441b9c97d471a937ba7d42625f1b530db05e572f1cb7d401d95c893 \ + --hash=sha256:116c71a81e046ba56551d8ab68067ca7034d94b617545316d460a452c5c3c289 \ + --hash=sha256:1af60988d47534246d9525f77288fdd9de652608a4842815d9018570b959acc6 \ + --hash=sha256:2026651761bf83a0f31495cc0f70840d5c0d54388f41316e3f9cb51bd85e49a5 \ + --hash=sha256:20bd153aacc244e4c907d772c703fea82754c4db14f8aa64d75ff81b7b8ab92d \ + --hash=sha256:26de93e6495078b6af4c4d93a42ca067b16cc0e95699526c82ab7d1025b4d3bf \ + --hash=sha256:288f6f6e7410cacb115fb851f3f18bf0e4231eb3f6cb5bd1cec0e7b25c4d039d \ + --hash=sha256:2db9187f3acf3cd33424ecdbaad75414c298ecd1513470df7bda885dcb68cc15 \ + --hash=sha256:2e9be5d05cd960914024412b5406fb75a82f8562f45912ff86255acbfdbfb78e \ + --hash=sha256:2fe5783676f0afba4a522c80b15e99dbf4e393c149ab610308a8ef1f04c6bcc8 \ + --hash=sha256:3084161fc3e963056232ef8d937449a2943852e07101f5a136c8f3cfa4119217 \ + --hash=sha256:34f213d59219a9c3ca14e94a825f585811a68ac56b4118b4dc388b5b14afc108 \ + --hash=sha256:399b9b79ccfcf50ca3bad7692bc098bb8eade88d7d5e15773b7f866c91156d0c \ + --hash=sha256:43dfc5e733808962a822ff6d9c29f3039a3cfb3620706f5953e17cfe4496724c \ + --hash=sha256:457827ba82261aa2ae6ac06a46d0043ab12ba7216b82d87ae1434ec0f29736d6 \ + --hash=sha256:47aca565a39c9a6067927871973ca827023e8b65ba6c5747f4c228c8d7ddc04f \ + --hash=sha256:4bd1a7676ee2a4c8e2f7f2550bece994f9f89e58afb96088964145a83af7408b \ + --hash=sha256:4dd3d8443970eaa02ab5ae45ce584b061f2799cd9f7e875190e2617440c1f9d4 \ + --hash=sha256:4df75b3ebbb8cfdb9bf8b213b168620b88fd92d0c16a8bc9f9234630b282db59 \ + --hash=sha256:50484d563f8bfa723c74c944b0bb15b9e054db9c889348c8c307abcbee75ab92 \ + --hash=sha256:50e3d0c72ea15391ba9531ead7f2068a67c5b18a6a365fef3127583aaadd1725 \ + --hash=sha256:545fc04f2d592e4350f59deb0818886c1b444ffba3bec535b4fbb97191aaf769 \ + --hash=sha256:56fd15ea8f4c948864fa5ebd9261c67cf7b89a1c517a0caef4df75446a7af18c \ + --hash=sha256:5897242d455461f2c5b82d7397b29341fd11e85bf3608a522177071044784ee8 \ + --hash=sha256:5d350864269d56f51ab81ab750c9259ae5cad3152c0680baef143dcec92206a1 \ + --hash=sha256:5dd6eec15b13329abe66cc241b484002ecb0e17d694491c944a22410a6a9e5e2 \ + --hash=sha256:63e4c175cbce8c3adc22dca5e6154588ae673f6c55374d156f3dac732c88d7de \ + --hash=sha256:69ef5b363afff7150a1fbe788007e307b9802a2eb6ad92ed51ab94e6ad2674c6 \ + --hash=sha256:6b62af27e65bb39276a66533655a2fa3c60a487b03935721c45b7809527979be \ + --hash=sha256:6cd67d3d017296d98ff505529104299f78433e4b8af31b55003d901a62bbebe9 \ + --hash=sha256:718c9bd369288aca5fa929df6dbf66fdbe9768d90940a940c0b5cdc96ade4309 \ + --hash=sha256:76a35e9e19a7c883c422ffa378e9a04bc98cb3b29648c5831596401298ee51e6 \ + --hash=sha256:7947a425d1be3e744707ee58c6cb318b93a56e08f080722dcc0347e0b7a1bb9a \ + --hash=sha256:79e7f98525b60b3c14524e0a4e1fedf7654657b6e02eb25f1be897ab097706f3 \ + --hash=sha256:7c4c82b1689b23b1b5e6a603164ed2be41b6f6de292a698b98ba2381e889eb9d \ + --hash=sha256:7dc87073ba3a40dd65591a2100aa71602107443bf10770579ff9c8a3242edb94 \ + --hash=sha256:7f3a6aa6e70fc27e4ff5c479f13cc9fc26a56347610f5f8b50396a0d344c5f55 \ + --hash=sha256:803f255f10d63420979b1909ef976e7d30dec42025c9b067fc1d2040cc365a7e \ + --hash=sha256:884453860de029380dded8f3c1918af2d8eb5adf8010261645c7e5c88c2b5428 \ + --hash=sha256:886882367dbc985f5736356105798f2ae6e794e671fc605476cbe2e73838a9bb \ + --hash=sha256:8a6405d34c394c65e4f73a1d300c001f304f08e529d2ed6413b46ee3037956eb \ + --hash=sha256:916a6abf3632e592b937c3d04c00a6efadd8fd30539cdcd4e6e4d92be7ca5d90 \ + --hash=sha256:9178277f72d144a6c7704d7ae7fa15b7b86f0f0796f0e1049c7b4ef748a662ef \ + --hash=sha256:949b5e9eeaa4ecb4c7e9c2a4689dddce60929dd1ff9c76a889cdbabe8bbf2171 \ + --hash=sha256:94c48b4a2a4b1d22246f48e2b11cae01ec7d23f0c9123f8bb822839ad79d0a88 \ + --hash=sha256:96ad46f5f56f70fab2be9e5f3165a21be58d633b90bf6e67fc52a856695e4bcf \ + --hash=sha256:98f6ebe28831a482981ecfeedc8237047878424ad0c1add2c7f366ba44a20452 \ + --hash=sha256:9eac95b4278bd53115903d89118a2c908398ee8bdfd977ae844f1bd2b02b917c \ + --hash=sha256:a425a0a868cf8e9c6e93e1cda4b758cdfd314bb9a4fc916c5742c934e3613480 \ + --hash=sha256:a68e3724b7dab761c01816aaa64b0903734d999d5589daf97c14ef5cc0629a8e \ + --hash=sha256:a86d5d1d75e61df060c1e56596b6b0a4422a929dff19cc3dbfd5eee762c86b61 \ + --hash=sha256:a9b8f51e08c3f983d857c3889930af9ddecc768453822076683664772d87e374 \ + --hash=sha256:aadce42147fc09dcef1afa892485311e824c050352e1aa6e47f56b9b27af4cf0 \ + --hash=sha256:ae7966f205b5a7fde93b44ca8fed37c1c8539328d7f179b1197de34eceaceb5f \ + --hash=sha256:b0445fa9880ead81f5a7d0efc0b9c977a947d8052c43519aceeaf56eabaf6843 \ + --hash=sha256:b0732343cdc4273b5921268026dd7266f75466eb21873cb7635a200d9d9c3fac \ + --hash=sha256:b11a127ac590fc991e8a02c2d7e1ac86e8141c92f78546f18b5c904064a0552c \ + --hash=sha256:b33e13e537e3afd1627d421a142a12bbbe601543558a391a6fae593356842f6e \ + --hash=sha256:b5363932a5aab67010ae1a6205c567d1ef256fb333bc23c27582481606be480c \ + --hash=sha256:b54853c2371bf0e38d67da379519deb6fbe70055efb32f6607081641af3dc752 \ + --hash=sha256:b67cc21a14327a0eb0f47bc3d7e59ec08031c7c55220ece672f9476e7a8068d3 \ + --hash=sha256:bb0013795b40db5cf361e6f21ee7cda09627cf294977149b50e217d7fe9a2f03 \ + --hash=sha256:bd393683129f446a75d8634306aed7e377627098a1286ff3af2a4f1736742820 \ + --hash=sha256:c038b9939da3035afb6cb2f465f18163e8f070aba0482923ecff9443def67178 \ + --hash=sha256:c50bc308fa29767ed8f53a8d33b7633a9e14718ced038ed89d41b886e301da32 \ + --hash=sha256:c582c46b1bb0b19f1a5f4c1312f1b640c21d78c371a6615c34025b16ee56369b \ + --hash=sha256:c77a7330dd15c7eb5fd3631dc646fc96327f98db8181138766bd14d3e905f0ba \ + --hash=sha256:c9e29a13d2fd9be3e7d8c26c7ef4ba60b5bc7efbc9dbdf24454c7e9ebba31768 \ + --hash=sha256:ca366c2e2a54e2f663f4529b189fdeb6e14d419b1c78b754ec1744f3c01070d4 \ + --hash=sha256:ce19887268e90ee81a3957eef5e46a70ecc000713796639f83828b950343f49e \ + --hash=sha256:cffbc50e0767396ed483900900dd58ce4351bc0d40e64bced8694bd41864cc71 \ + --hash=sha256:d29d1b9857c65f8cb3a29270732e1591b9bacf89de9d13fa764f79f07d8f1fd2 \ + --hash=sha256:d4688862f957c8629d557d084f20b2d803f8738b6c4066802a0b1cc472e088d9 \ + --hash=sha256:e5ddb2388610799fc46abe389600625058f2a73867e63e20107c5ad5ffa57c47 \ + --hash=sha256:e89605afebbd2d4b045bccfdc12a14b16fe8ccbae05f64b4b4c64a97dad1c891 \ + --hash=sha256:ea2da0459b951ee461bd4e02b8904890bd1c4263999d291c5cd01e6620177ad4 \ + --hash=sha256:ec9139baa3f85b65adc700eafa03ed04995ca8533dd56c924f0e458ffec044ab \ + --hash=sha256:eda4c661e68dddd56c8fbfe1ca35e40dd2afd973f7ebb1605f4d151edc63dff8 \ + --hash=sha256:f0a547e4350d1fa32624d3eab51eff8cf329f4cae110b4ea0402486b1da8be40 \ + --hash=sha256:f39a2a5ded23b9b9194ec45740dce57177b80f86c6d8eba953d3ff1a25c97766 \ + --hash=sha256:f3a0bda83c18195c361b5500377d0767749f128564ca95b42c8849fd475bb327 \ + --hash=sha256:f744b5eb1469bf92dd143d36570d2bdbbdc88fe5cb0b5405e53dd34f479cbd8a \ + --hash=sha256:f9f0bbfb6787b97c51516f3ccf97737d504db5d239ad44527673b81f598b84ab \ + --hash=sha256:fa9720e56663cc3649d62b4b5f3145e94b8f5611e8a8e1b46507777249d46aad \ + --hash=sha256:fb6ec40cef63b1922083d33bfef2f91fc0b0bc07b5b09bfee0b0f1717d558292 \ + --hash=sha256:fe5231e8afd069c742ac5b4f96344a0fe4aff52df8e53ef87faebf77f827822c +requests-toolbelt==1.0.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6 \ + --hash=sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06 +requests==2.32.3 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ + --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 +secretstorage==3.3.3 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "linux" \ + --hash=sha256:2403533ef369eca6d2ba81718576c5e0f564d5cca1b58f73a8b23e7d4eeebd77 \ + --hash=sha256:f356e6628222568e3af06f2eba8df495efa13b3b63081dafd4f7d9a7b7bc9f99 +setuptools==75.1.0 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2 \ + --hash=sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538 +shellingham==1.5.4 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686 \ + --hash=sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de +tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" \ + --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ + --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f +tomlkit==0.13.2 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde \ + --hash=sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79 +trove-classifiers==2024.9.12 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:4b46b3e134a4d01999ac5bc6e528afcc10cc48f0f724f185f267e276005768f4 \ + --hash=sha256:f88a27a892891c87c5f8bbdf110710ae9e0a4725ea8e0fb45f1bcadf088a491f +urllib3==2.2.3 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac \ + --hash=sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9 +virtualenv==20.26.6 ; python_version >= "3.9" and python_version < "4.0" \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 +xattr==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "darwin" \ + --hash=sha256:00d2b415cf9d6a24112d019e721aa2a85652f7bbc9f3b9574b2d1cd8668eb491 \ + --hash=sha256:0683dae7609f7280b0c89774d00b5957e6ffcb181c6019c46632b389706b77e6 \ + --hash=sha256:08f61cbed52dc6f7c181455826a9ff1e375ad86f67dd9d5eb7663574abb32451 \ + --hash=sha256:0a9c431b0e66516a078125e9a273251d4b8e5ba84fe644b619f2725050d688a0 \ + --hash=sha256:0f06e0c1e4d06b4e0e49aaa1184b6f0e81c3758c2e8365597918054890763b53 \ + --hash=sha256:1a5921ea3313cc1c57f2f53b63ea8ca9a91e48f4cc7ebec057d2447ec82c7efe \ + --hash=sha256:23705c7079b05761ff2fa778ad17396e7599c8759401abc05b312dfb3bc99f69 \ + --hash=sha256:24d97f0d28f63695e3344ffdabca9fcc30c33e5c8ccc198c7524361a98d526f2 \ + --hash=sha256:27272afeba8422f2a9d27e1080a9a7b807394e88cce73db9ed8d2dde3afcfb87 \ + --hash=sha256:46a641ac038a9f53d2f696716147ca4dbd6a01998dc9cd4bc628801bc0df7f4d \ + --hash=sha256:47a3bdfe034b4fdb70e5941d97037405e3904accc28e10dbef6d1c9061fb6fd7 \ + --hash=sha256:4cb70c16e7c3ae6ba0ab6c6835c8448c61d8caf43ea63b813af1f4dbe83dd156 \ + --hash=sha256:54cb15cd94e5ef8a0ef02309f1bf973ba0e13c11e87686e983f371948cfee6af \ + --hash=sha256:6461a43b585e5f2e049b39bcbfcb6391bfef3c5118231f1b15d10bdb89ef17fe \ + --hash=sha256:6480589c1dac7785d1f851347a32c4a97305937bf7b488b857fe8b28a25de9e9 \ + --hash=sha256:687e7d18611ef8d84a6ecd8f4d1ab6757500c1302f4c2046ce0aa3585e13da3f \ + --hash=sha256:6881b120f9a4b36ccd8a28d933bc0f6e1de67218b6ce6e66874e0280fc006844 \ + --hash=sha256:6ad47d89968c9097900607457a0c89160b4771601d813e769f68263755516065 \ + --hash=sha256:78b377832dd0ee408f9f121a354082c6346960f7b6b1480483ed0618b1912120 \ + --hash=sha256:793c01deaadac50926c0e1481702133260c7cb5e62116762f6fe1543d07b826f \ + --hash=sha256:7a92aff66c43fa3e44cbeab7cbeee66266c91178a0f595e044bf3ce51485743b \ + --hash=sha256:7e4ca0956fd11679bb2e0c0d6b9cdc0f25470cc00d8da173bb7656cc9a9cf104 \ + --hash=sha256:83652910ef6a368b77b00825ad67815e5c92bfab551a848ca66e9981d14a7519 \ + --hash=sha256:9013f290387f1ac90bccbb1926555ca9aef75651271098d99217284d9e010f7c \ + --hash=sha256:918e1f83f2e8a072da2671eac710871ee5af337e9bf8554b5ce7f20cdb113186 \ + --hash=sha256:96ca300c0acca4f0cddd2332bb860ef58e1465d376364f0e72a1823fdd58e90d \ + --hash=sha256:9b1664edf003153ac8d1911e83a0fc60db1b1b374ee8ac943f215f93754a1102 \ + --hash=sha256:9c5a78c7558989492c4cb7242e490ffb03482437bf782967dfff114e44242343 \ + --hash=sha256:9d4f71b673339aeaae1f6ea9ef8ea6c9643c8cd0df5003b9a0eaa75403e2e06c \ + --hash=sha256:9dcd5dfbcee73c7be057676ecb900cabb46c691aff4397bf48c579ffb30bb963 \ + --hash=sha256:a20de1c47b5cd7b47da61799a3b34e11e5815d716299351f82a88627a43f9a96 \ + --hash=sha256:afacebbc1fa519f41728f8746a92da891c7755e6745164bd0d5739face318e86 \ + --hash=sha256:b0d73150f2f9655b4da01c2369eb33a294b7f9d56eccb089819eafdbeb99f896 \ + --hash=sha256:b489b7916f239100956ea0b39c504f3c3a00258ba65677e4c8ba1bd0b5513446 \ + --hash=sha256:b6ceb9efe0657a982ccb8b8a2efe96b690891779584c901d2f920784e5d20ae3 \ + --hash=sha256:b735ac2625a4fc2c9343b19f806793db6494336338537d2911c8ee4c390dda46 \ + --hash=sha256:caab2c2986c30f92301f12e9c50415d324412e8e6a739a52a603c3e6a54b3610 \ + --hash=sha256:ccab735d0632fe71f7d72e72adf886f45c18b7787430467ce0070207882cfe25 \ + --hash=sha256:cd11e917f5b89f2a0ad639d9875943806c6c9309a3dd02da5a3e8ef92db7bed9 \ + --hash=sha256:cebcf8a303a44fbc439b68321408af7267507c0d8643229dbb107f6c132d389c \ + --hash=sha256:d1059b2f726e2702c8bbf9bbf369acfc042202a4cc576c2dec6791234ad5e948 \ + --hash=sha256:d1418705f253b6b6a7224b69773842cac83fcbcd12870354b6e11dd1cd54630f \ + --hash=sha256:d44e8f955218638c9ab222eed21e9bd9ab430d296caf2176fb37abe69a714e5c \ + --hash=sha256:d6eb7d5f281014cd44e2d847a9107491af1bf3087f5afeded75ed3e37ec87239 \ + --hash=sha256:dab29d9288aa28e68a6f355ddfc3f0a7342b40c9012798829f3e7bd765e85c2c \ + --hash=sha256:dba4f80b9855cc98513ddf22b7ad8551bc448c70d3147799ea4f6c0b758fb466 \ + --hash=sha256:dc53cab265f6e8449bd683d5ee3bc5a191e6dd940736f3de1a188e6da66b0653 \ + --hash=sha256:dd43978966de3baf4aea367c99ffa102b289d6c2ea5f3d9ce34a203dc2f2ab73 \ + --hash=sha256:dda2684228798e937a7c29b0e1c7ef3d70e2b85390a69b42a1c61b2039ba81de \ + --hash=sha256:ded771eaf27bb4eb3c64c0d09866460ee8801d81dc21097269cf495b3cac8657 \ + --hash=sha256:e0c80bbf55339c93770fc294b4b6586b5bf8e85ec00a4c2d585c33dbd84b5006 \ + --hash=sha256:e189e440bcd04ccaad0474720abee6ee64890823ec0db361fb0a4fb5e843a1bf \ + --hash=sha256:e2255f36ebf2cb2dbf772a7437ad870836b7396e60517211834cf66ce678b595 \ + --hash=sha256:ef2fa0f85458736178fd3dcfeb09c3cf423f0843313e25391db2cfd1acec8888 \ + --hash=sha256:f6ad2a7bd5e6cf71d4a862413234a067cf158ca0ae94a40d4b87b98b62808498 \ + --hash=sha256:fa6a7af7a4ada43f15ccc58b6f9adcdbff4c36ba040013d2681e589e07ae280a \ + --hash=sha256:fecbf3b05043ed3487a28190dec3e4c4d879b2fcec0e30bafd8ec5d4b6043630 \ + --hash=sha256:ff6223a854229055e803c2ad0c0ea9a6da50c6be30d92c198cf5f9f28819a921 +zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.12" \ + --hash=sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350 \ + --hash=sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29 +pip==24.2 --hash=sha256:2cd581cf58ab7fcfca4ce8efa6dcacd0de5bf8d0a3eb9ec927e07405f4d9e2a2 diff --git a/docker/python/bootstrap/lockfiles/requirements-3.9.txt b/docker/python/bootstrap/lockfiles/requirements-3.9.txt new file mode 100644 index 000000000000..43a3c2405739 --- /dev/null +++ b/docker/python/bootstrap/lockfiles/requirements-3.9.txt @@ -0,0 +1,3 @@ +pip +poetry +setuptools diff --git a/docs/how_to/dev/setup_rpc_system.rst b/docs/how_to/dev/setup_rpc_system.rst index 0131619b71d2..f61b7477f5c0 100644 --- a/docs/how_to/dev/setup_rpc_system.rst +++ b/docs/how_to/dev/setup_rpc_system.rst @@ -185,7 +185,7 @@ Troubleshooting The package ``numpy`` is imported in some Python files which RPC server dependent on, and eliminating the import relationship is difficult, for some devices cross compiling ``numpy`` is very hard to do too. -But acturally the TVM runtime doesn't really dependent on ``numpy``, so a very simple workaround is create a dummy ``numpy``, just need to copy the below content into a file named ``numpy.py`` and place it into directory like ``/usr/local/lib/python3.8/site-packages``. +But acturally the TVM runtime doesn't really dependent on ``numpy``, so a very simple workaround is create a dummy ``numpy``, just need to copy the below content into a file named ``numpy.py`` and place it into directory like ``/usr/local/lib/python3.9/site-packages``. .. code-block:: python @@ -242,4 +242,4 @@ But acturally the TVM runtime doesn't really dependent on ``numpy``, so a very s 2. The lack of ``cloudpickle`` on device machine caused the RPC server can't be launched. ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Because ``cloudpickle`` package is a pure Python package, so just copying it from other machine to the directory like ``/usr/local/lib/python3.8/site-packages`` of the device machine will resolve the problem. +Because ``cloudpickle`` package is a pure Python package, so just copying it from other machine to the directory like ``/usr/local/lib/python3.9/site-packages`` of the device machine will resolve the problem. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 4127266da7e2..be88e234634f 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -16,7 +16,7 @@ # under the License. """The TensorIR schedule class""" import inspect -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error @@ -65,8 +65,11 @@ def __init__(self) -> None: RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name -# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8 -_ERROR_RENDER_LEVEL: Dict[str, int] = {"detail": 0, "fast": 1, "none": 2} +_ERROR_RENDER_LEVEL: Dict[Literal["detail", "fast", "none"], int] = { + "detail": 0, + "fast": 1, + "none": 2, +} def _parse_error_render_level(error_render_level: str) -> int: From dc2c5a28c9132aa314cca237ffbe32e1bad8dd2a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 3 Oct 2024 06:50:45 -0700 Subject: [PATCH 186/202] [TVMScript][TIR] Add source kernel intetration via call_kernel (#17434) * [TVMScript][TIR] Add source kernel intetration via call_kernel * lint * lint --- .../script/ir_builder/tir/external_kernel.py | 62 ++++++++++- .../relax/test_tir_call_source_kernel.py | 100 ++++++++++++++++++ 2 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 tests/python/relax/test_tir_call_source_kernel.py diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py index 8c2467fad330..405e1e6cbf93 100644 --- a/python/tvm/script/ir_builder/tir/external_kernel.py +++ b/python/tvm/script/ir_builder/tir/external_kernel.py @@ -18,14 +18,16 @@ import json import logging import tempfile +from pathlib import Path from typing import Any, Dict, List, Tuple, Union from tvm import __version__ as tvm_version from tvm import tir -from tvm.runtime import Module, load_module +from tvm.runtime import Module, load_module, const +from tvm.contrib import nvcc -class BaseKernel: +class BaseKernel: # pylint: disable=too-few-public-methods """Base class for external kernels.""" def compile_to_device_module( @@ -91,6 +93,60 @@ def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_n return kernel_module +class SourceKernel(BaseKernel): # pylint: disable=too-few-public-methods + """A kernel from source code.""" + + def __init__(self, source_code: str): + self.source_code = source_code + + def compile_to_device_module( # pylint: disable=arguments-differ + self, grid: List[List[Union[int, tir.PrimExpr]]], *args: List[Any], **kwargs: Dict[str, Any] + ) -> Tuple[str, Module, List[Any]]: + """Compile the kernel to a device module.""" + from tvm.relax.frontend.nn import SourceModule # pylint: disable=import-outside-toplevel + + kernel_name = kwargs["kernel_name"] + assert len(grid) == 2, ( + "grid should be two list of integers, representing the dimension of " + "['blockIdx.x', 'blockIdx.y', 'blockIdx.z'] and " + "['threadIdx.x', 'threadIdx.y', 'threadIdx.z']" + ) + assert isinstance(grid[0], (list, tuple)) and isinstance(grid[1], (list, tuple)) + launch_param_tags = ["blockIdx.x", "blockIdx.y", "blockIdx.z"][: len(grid[0])] + [ + "threadIdx.x", + "threadIdx.y", + "threadIdx.z", + ][: len(grid[1])] + runtime_args = [arg if hasattr(arg, "dtype") else const(arg) for arg in args] + kernel_arg_types = [arg.dtype for arg in runtime_args] + runtime_args = runtime_args + list(grid[0]) + list(grid[1]) + + # Reuse compilation path from SourceModule + compile_options = SourceModule.get_compile_options("cu") + source_code = self.source_code + try: + source_path = Path(source_code) + if source_path.is_file(): + with open(source_path, "r") as f: + source_code = f.read() + except: # pylint: disable=bare-except + pass + + with tempfile.TemporaryDirectory() as temp_dir: + ptx_path = f"{temp_dir}/{kernel_name}.ptx" + nvcc.compile_cuda( + source_code, target_format="ptx", options=compile_options, path_target=ptx_path + ) + with open(ptx_path, "r") as f: + ptx = f.read() + + kernel_module = self._create_cuda_module( + ptx, kernel_arg_types, launch_param_tags, kernel_name + ) + + return kernel_name, kernel_module, runtime_args + + def call_kernel( kernel, launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]], @@ -123,6 +179,8 @@ def call_kernel( from .triton import TritonKernel # pylint: disable=import-outside-toplevel kernel = TritonKernel(kernel) + elif kernel_type == "builtins.str": + kernel = SourceKernel(kernel) else: raise ValueError("Unsupported kernel type {}".format(kernel_type)) diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py new file mode 100644 index 000000000000..9a877ad35f8f --- /dev/null +++ b/tests/python/relax/test_tir_call_source_kernel.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import tir as T, ir as I, relax as R + +add_cuda_source = """ +extern "C" __global__ void add_kernel(float* x, float* y, float* output, int n_elements) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n_elements) { + output[i] = x[i] + y[i]; + } +} +""" + + +@tvm.testing.requires_cuda +def test_tir_call_source_kernel(): + @I.ir_module + class Module: + @T.prim_func + def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None: + T.func_attr({"global_symbol": "add"}) + m = T.int64() + x = T.match_buffer(x_handle, (m,), "float32") + y = T.match_buffer(y_handle, (m,), "float32") + output = T.match_buffer(output_handle, (m,), "float32") + with T.block("root"): + T.reads(x[0:m], y[0:m]) + T.writes(output[0:m]) + BLOCK_SIZE = T.meta_var(64) + T.call_kernel( + add_cuda_source, + ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)), + x.data, + y.data, + output.data, + m, + kernel_name="add_kernel", + ) + + @R.function + def main(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")): + m = T.int64() + with R.dataflow(): + output = R.call_tir(Module.add, [x, y], relax.TensorStructInfo((m,), "float32")) + R.output(output) + return output + + @I.ir_module + class Parsed: + @T.prim_func + def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle): + m = T.int64() + x = T.match_buffer(x_handle, (m,)) + y = T.match_buffer(y_handle, (m,)) + output = T.match_buffer(output_handle, (m,)) + with T.block("root"): + T.reads(x[0:m], y[0:m]) + T.writes(output[0:m]) + T.call_packed( + "add_kernel", + x.data, + y.data, + output.data, + m, + (m + T.int64(64) - T.int64(1)) // T.int64(64), + 64, + ) + + tvm.ir.assert_structural_equal(Module["add"], Parsed["add"]) + assert len(Module.get_attr("external_mods")) == 1 + + device = tvm.cuda(0) + x_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device) + y_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device) + output_np = x_nd.numpy() + y_nd.numpy() + + with tvm.target.Target("cuda"): + lib = relax.build(Module) + output_nd = tvm.runtime.relax_vm.VirtualMachine(lib, device)["main"](x_nd, y_nd) + tvm.testing.assert_allclose(output_nd.numpy(), output_np, rtol=1e-5) From 79abc0356ee66f3dbdd8bde3cbfcbf88a2ed746e Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Thu, 3 Oct 2024 19:20:58 +0530 Subject: [PATCH 187/202] [KVCACHE] Improved schedule for prefill attention (#17432) * [KVCACHE] Improved schedule for prefill attention Improvements Added Tranpose to K for better Vectorization during Matmul. Improved Load Schedule. Improved a bit more than 2x is most cases. Llama-2 7B observation -------kernel----------------baseline----------optimized- ---batch_prefill_ragged_kv----15 ms-------------7.1 ms * Update kv_cache.py --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 60 ++++++++++++++++---- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 9b16fc2fbfee..fd866ae06c16 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -925,8 +925,12 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 - if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 256 if H_kv < 8 else 512 + if target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + # Keeping lower thread limit for this kernel on adreno target + # to avoid register spill + THREAD_LIMIT = 256 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) @@ -1570,7 +1574,11 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + tile_x, tile_y, tile_z = ( + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + d, + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + ) # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -1580,6 +1588,12 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], tile_z = 8 num_warps = 2 + if target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + LOAD_VEC = 16 // ((DataType(dtype).bits + 7) // 8) # 16 bytes + NUM_BLKS = group_size * 8 + # fmt: off @T.prim_func def batch_prefill_ragged_kv( # pylint: disable=too-many-branches @@ -1708,8 +1722,6 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for lz, ly in T.grid(tile_z, tile_y): with T.block("K_load"): i, j = T.axis.remap("SS", [lz, ly]) - T.reads() - T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else( @@ -1824,6 +1836,14 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_prefill_ragged_kv) + get_extent = lambda *lps: [int(sch.get(lp).extent) for lp in lps] + + def get_vecsize(extent): + return min(LOAD_VEC, (extent & ~(extent - 1))) + + def getxy_vecsize(x, y, t): + assert (x * y) % t == 0 + return min(get_vecsize(y), get_vecsize(x * y // t)) def get_tile_size(x, y, t): cnt = (x * y) // t @@ -1837,26 +1857,37 @@ def get_tile_size(x, y, t): def apply_to_qkv_load(sch: tir.Schedule, block): loop_x, loop_y = sch.get_loops(block)[-2:] - loop = sch.fuse(loop_x, loop_y) - _, ty, tx, vec = sch.split( - loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True - ) + x_extent, y_extent = get_extent(loop_x, loop_y) + vec_size = getxy_vecsize(x_extent, y_extent, bdx * num_warps) + yo, yv = sch.split(loop_y, [None, vec_size]) + yo_extent = y_extent // vec_size + tile_x, tile_y = get_tile_size(x_extent, yo_extent, (bdx * num_warps)) + xo, xi = sch.split(loop_x, [tile_x, None]) + yo, yi = sch.split(yo, [tile_y, None]) + sch.reorder(xi, yi, xo, yo) + t = sch.fuse(xi, yi) + ty, tx = sch.split(t, [num_warps, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) + sch.vectorize(yv) def apply_to_so_ewise(sch: tir.Schedule, block, tile): loop_x, loop_y = sch.get_loops(block)[-2:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) + sch.unroll(xi) + yiv_extent = get_vecsize(tile[1]) + yio, yiv = sch.split(yi, [None, yiv_extent]) + sch.unroll(yio) + sch.vectorize(yiv) t = sch.fuse(xo, yo) ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + sch: tir.Schedule, block, tile, read_0, read_1, r_len=16, k_major=False ): loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) @@ -1872,6 +1903,12 @@ def apply_to_gemm( # pylint: disable=unused-argument sch.reorder(ko, xi, yi, ki) else: sch.reorder(ko, ki, xi, yi) + yiv_extent = get_vecsize(tile[1]) + yio, yiv = sch.split(yi, [None, yiv_extent]) + sch.unroll(yio) + sch.vectorize(yiv) + sch.unroll(xi) + sch.unroll(ki) sch.decompose_reduction(block, ty) def apply_to_md(sch, block): @@ -1880,6 +1917,7 @@ def apply_to_md(sch, block): sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") + sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i)) tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) From 9fdb86d3f6bccc41a772328b5b0442908bc9f9a9 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 3 Oct 2024 22:36:55 +0800 Subject: [PATCH 188/202] [Relax][ONNX] Expand op support for ONNX frontend (#17427) * [Relax][ONNX] Expand op support for ONNX frontend This PR adds a variety of ONNX ops to the Relax frontend, including: - Acos - Acosh - And - Asin - Asinh - Atan - Atanh - BitwiseAnd - BitwiseOr - BitwiseXor - Ceil - ConcatFromSequence - ConvTranspose - Cosh - DepthToSpace - FastGelu - Floor - GlobalLpPool - GlobalMaxPool - GreaterOrEqual - IsInf - IsNaN - LeakyRelu - LogSoftmax - MaxUnpool - Mean - MeanVarianceNormalization - Mish - Or - PRelu - Round - Scatter - ScatterElements - Selu - SequenceAt - SequenceConstruct - SequenceEmpty - SequenceErase - SequenceInsert - SequenceLength - Shrink - Sinh - Size - Softplus - Softsign - SpaceToDepth - SplitToSequence - Tan - ThresholdedRelu - TopK - Unique - Xor Also remains a few ops that are not supported yet, see the commented out ops in the ONNX frontend. * lint * lint * lint * update for ci --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 1302 +++++++++++++---- python/tvm/relax/op/set.py | 8 +- python/tvm/relax/transform/legalize_ops/nn.py | 9 +- tests/python/relax/test_frontend_onnx.py | 664 +++++++-- tests/python/relax/test_relax_operators.py | 2 +- .../relax/test_transform_legalize_ops_nn.py | 47 + 6 files changed, 1617 insertions(+), 415 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 462d1cf92c01..5777f51fe296 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -34,14 +34,15 @@ Not all TVM kernels currently support dynamic shapes, please file an issue on github.com/apache/tvm/issues if you hit an error with dynamic kernels. """ +import math import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as _np import onnx.onnx_ml_pb2 import tvm -from tvm import relax, tir, topi +from tvm import TVMError, relax, tir, topi from tvm.ir import IRModule from tvm.ir.supply import NameSupply from tvm.tir.generic import cast @@ -236,28 +237,176 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.matmul(inputs[0], inputs[1]) -class Div(OnnxOpConverter): - """Converts an onnx Div node into an equivalent Relax expression.""" +class BinaryBase(OnnxOpConverter): + """Converts an onnx BinaryBase node into an equivalent Relax expression.""" + + numpy_op: Callable = None + relax_op: Callable = None @classmethod - def _impl_v14(cls, bb, inputs, attr, params): + def _impl_v1(cls, bb, inputs, attr, params): + if cls.numpy_op is None or cls.relax_op is None: + raise ValueError("Numpy and Relax operators must be defined for BinaryBase.") if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() / inputs[1].data.numpy() + output = cls.numpy_op( # pylint: disable=not-callable + inputs[0].data.numpy(), inputs[1].data.numpy() + ) return relax.const(output, inputs[0].struct_info.dtype) if any([isinstance(inp, relax.PrimValue) for inp in inputs]): x = ( - int(inputs[0].value) + _np.array(inputs[0].value) if isinstance(inputs[0], relax.PrimValue) else inputs[0].data.numpy() ) y = ( - int(inputs[1].value) + _np.array(inputs[0].value) if isinstance(inputs[1], relax.PrimValue) else inputs[1].data.numpy() ) - return relax.PrimValue(int(x / y)) + return relax.PrimValue(cls.numpy_op(x, y)) # pylint: disable=not-callable + + return cls.relax_op(inputs[0], inputs[1]) # pylint: disable=not-callable + + +class Add(BinaryBase): + """Converts an onnx Add node into an equivalent Relax expression.""" + + numpy_op = _np.add + relax_op = relax.op.add + + +class Sub(BinaryBase): + """Converts an onnx Sub node into an equivalent Relax expression.""" + + numpy_op = _np.subtract + relax_op = relax.op.subtract + + +class Mul(BinaryBase): + """Converts an onnx Mul node into an equivalent Relax expression.""" + + numpy_op = _np.multiply + relax_op = relax.op.multiply + + +class Div(BinaryBase): + """Converts an onnx Div node into an equivalent Relax expression.""" + + numpy_op = _np.divide + relax_op = relax.op.divide + + +class Pow(BinaryBase): + """Converts an onnx Pow node into an equivalent Relax expression.""" + + numpy_op = _np.power + relax_op = relax.op.power + + +class And(BinaryBase): + """Converts an onnx And node into an equivalent Relax expression.""" + + numpy_op = _np.logical_and + relax_op = relax.op.logical_and - return relax.op.divide(inputs[0], inputs[1]) + +class Or(BinaryBase): + """Converts an onnx Or node into an equivalent Relax expression.""" + + numpy_op = _np.logical_or + relax_op = relax.op.logical_or + + +class Xor(BinaryBase): + """Converts an onnx Xor node into an equivalent Relax expression.""" + + numpy_op = _np.logical_xor + relax_op = relax.op.logical_xor + + +class Less(BinaryBase): + """Converts an onnx Less node into an equivalent Relax expression.""" + + numpy_op = _np.less + relax_op = relax.op.less + + +class LessOrEqual(BinaryBase): + """Converts an onnx LessEqual node into an equivalent Relax expression.""" + + numpy_op = _np.less_equal + relax_op = relax.op.less_equal + + +class Greater(BinaryBase): + """Converts an onnx Greater node into an equivalent Relax expression.""" + + numpy_op = _np.greater + relax_op = relax.op.greater + + +class GreaterOrEqual(BinaryBase): + """Converts an onnx GreaterEqual node into an equivalent Relax expression.""" + + numpy_op = _np.greater_equal + relax_op = relax.op.greater_equal + + +class Equal(OnnxOpConverter): + """Converts an onnx Equal node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + if all([isinstance(inp, relax.Constant) for inp in inputs]): + output = inputs[0].data.numpy() == inputs[1].data.numpy() + return relax.const(output, output.dtype) + elif all([isinstance(inp, (relax.Constant, relax.ShapeExpr)) for inp in inputs]): + lhs = get_prim_expr_list(inputs[0]) + rhs = get_prim_expr_list(inputs[1]) + if len(lhs) != len(rhs): + raise ValueError("Cannot compare two tensors with different shapes") + output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)] + return relax.const(output, "bool") + return relax.op.equal(inputs[0], inputs[1]) + + +class BitwiseBase(BinaryBase): + """Converts an onnx BitwiseBase node into an equivalent Relax expression.""" + + @classmethod + def base_impl(cls, bb, inputs, attr, params, py_func, relax_op): + valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] + for num, inp in enumerate(inputs): + if inp.struct_info.dtype not in valid_types: + raise ValueError( + f"Bitwise operations expect all inputs to have integer types, " + f"got {inp.struct_info.dtype} for input {num}" + ) + return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op) + + +class BitwiseAnd(BitwiseBase): + """Converts an onnx BitwiseAnd node into an equivalent Relax expression.""" + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and) + + +class BitwiseOr(BitwiseBase): + """Converts an onnx BitwiseOr node into an equivalent Relax expression.""" + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or) + + +class BitwiseXor(BitwiseBase): + """Converts an onnx BitwiseXor node into an equivalent Relax expression.""" + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor) class Sigmoid(OnnxOpConverter): @@ -277,6 +426,15 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.nn.softmax(inputs[0], axis=axis) +class LogSoftmax(OnnxOpConverter): + """Converts an onnx LogSoftmax node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + axis = attr.get("axis", -1) + return relax.op.nn.log_softmax(inputs[0], axis=axis) + + class Transpose(OnnxOpConverter): """Converts an onnx Transpose node into an equivalent Relax expression.""" @@ -375,67 +533,6 @@ def is_shape_like(x: Any) -> bool: return relax.op.concat(inputs, axis=axis) -class Add(OnnxOpConverter): - """Convert an onnx Add node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() + inputs[1].data.numpy() - return relax.const(output, output.dtype) - # If primvalues are involved, handle them directly. - if any([isinstance(inp, relax.PrimValue) for inp in inputs]): - x = ( - int(inputs[0].value) - if isinstance(inputs[0], relax.PrimValue) - else inputs[0].data.numpy() - ) - y = ( - int(inputs[1].value) - if isinstance(inputs[1], relax.PrimValue) - else inputs[1].data.numpy() - ) - return relax.PrimValue(int(x + y)) - return relax.op.add(inputs[0], inputs[1]) - - -class Sum(OnnxOpConverter): - """Convert an onnx Sum node into an equivalent Relax expression.""" - - @classmethod - def _impl_v1(cls, bb, inputs, attr, params): - for in_index in range(len(inputs) - 1): - inputs[in_index + 1] = relax.op.add(inputs[in_index], inputs[in_index + 1]) - - return inputs[len(inputs) - 1] - - -class Mul(OnnxOpConverter): - """Convert an onnx Mul node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - # When all inputs are constant, directly multiply. - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() * inputs[1].data.numpy() - return relax.const(output, output.dtype) - # If primvalues are involved, handle them directly. - if any([isinstance(inp, relax.PrimValue) for inp in inputs]): - x = ( - int(inputs[0].value) - if isinstance(inputs[0], relax.PrimValue) - else inputs[0].data.numpy() - ) - y = ( - int(inputs[1].value) - if isinstance(inputs[1], relax.PrimValue) - else inputs[1].data.numpy() - ) - return relax.PrimValue(int(x * y)) - - return relax.op.multiply(inputs[0], inputs[1]) - - class Cast(OnnxOpConverter): """Convert an onnx Cast node into an equivalent Relax expression.""" @@ -482,8 +579,38 @@ def _impl_v13(cls, bb, inputs, attr, params): shape_val = data[np_index] return relax.PrimValue(shape_val) - # TODO(jwfromm) Make relax.take work with other indices shape. - return bb.emit_te(topi.take, data, indices, axis) + return relax.op.take(data, indices, axis) + + +class Scatter(OnnxOpConverter): + """Convert an onnx Scatter node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis) + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + raise ValueError("Scatter is deprecated in ONNX 11") + + +class ScatterElements(OnnxOpConverter): + """Convert an onnx ScatterElements node into an equivalent Relax expression.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis) + + +class Size(OnnxOpConverter): + """Convert an onnx Size node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + # TODO(tvm-team): add native support for size op + return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0]))) class Gemm(OnnxOpConverter): @@ -542,29 +669,6 @@ def _impl_v13(cls, bb, inputs, attr, params): return out -class Gelu(OnnxOpConverter): - """Operator converter for Gelu from Microsoft onnxruntime contrib opset. - - gelu(x) = 0.5x(1 + erf(x/sqrt(2))) - """ - - @classmethod - def _impl_v1(cls, bb, inputs, attr, params): - return relax.op.nn.gelu(inputs[0]) - - -class BiasGelu(OnnxOpConverter): - """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset. - - bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2))) - """ - - @classmethod - def _impl_v1(cls, bb, inputs, attr, params): - inp = relax.op.add(inputs[0], inputs[1]) - return relax.op.nn.gelu(inp) - - class Where(OnnxOpConverter): """Convert an onnx Where node into an equivalent Relax expression.""" @@ -605,24 +709,6 @@ def _impl_v13(cls, bb, inputs, attr, params): return results -class Equal(OnnxOpConverter): - """Converts an onnx Equal node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() == inputs[1].data.numpy() - return relax.const(output, output.dtype) - elif all([isinstance(inp, (relax.Constant, relax.ShapeExpr)) for inp in inputs]): - lhs = get_prim_expr_list(inputs[0]) - rhs = get_prim_expr_list(inputs[1]) - if len(lhs) != len(rhs): - raise ValueError("Cannot compare two tensors with different shapes") - output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)] - return relax.const(output, "bool") - return relax.op.equal(inputs[0], inputs[1]) - - class Shape(OnnxOpConverter): """Converts an onnx Equal node into an equivalent Relax expression.""" @@ -643,22 +729,6 @@ def _impl_v13(cls, bb, inputs, attr, params): return data_info.shape -class Tanh(OnnxOpConverter): - """Converts an onnx Tanh node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - return relax.op.tanh(inputs[0]) - - -class Sqrt(OnnxOpConverter): - """Converts an onnx Sqrt node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - return relax.op.sqrt(inputs[0]) - - class Trilu(OnnxOpConverter): """Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor(s) @@ -691,12 +761,157 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.nn.relu(inputs[0]) -class Pow(OnnxOpConverter): - """Converts an onnx Pow node into an equivalent Relax expression.""" +class Elu(OnnxOpConverter): + """Converts an onnx Elu node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - return relax.op.power(inputs[0], inputs[1]) + def _impl_v1(cls, bb, inputs, attr, params): + alpha = float(attr.get("alpha", 1.0)) + return relax.expr.const(-alpha) * relax.op.nn.relu( + relax.expr.const(1.0) - relax.op.exp(inputs[0]) + ) + relax.op.nn.relu(inputs[0]) + + +class Selu(OnnxOpConverter): + """Converts an onnx Selu node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + alpha = attr.get("alpha", 1.67326319217681884765625) + gamma = attr.get("gamma", 1.05070102214813232421875) + return relax.const(gamma) * ( + relax.const(-alpha) * relax.op.nn.relu(relax.const(1.0) - relax.op.exp(inputs[0])) + + relax.op.nn.relu(inputs[0]) + ) + + +class Mish(OnnxOpConverter): + """Converts an onnx Mish node into an equivalent Relax expression. + + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) + """ + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + dtype = inputs[0].checked_type.dtype + return inputs[0] * relax.op.tanh( + relax.op.log(relax.const(1.0, dtype) + relax.op.exp(inputs[0])) + ) + + +class PRelu(OnnxOpConverter): + """Converts an onnx PRelu node into an equivalent Relax expression. + + f(x) = slope * x for x < 0, x for x >= 0 + """ + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + x = inputs[0] + slope = inputs[1] + # TODO(tvm-team): Should add a new op for this. + return x * slope + relax.op.nn.relu(x) * (relax.const(1.0) - slope) + + +class ThresholdedRelu(OnnxOpConverter): + """Converts an onnx ThresholdedRelu node into an equivalent Relax expression. + + f(x) = x for x > alpha, 0 otherwise + """ + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + x = inputs[0] + alpha = attr.get("alpha", 1.0) + return relax.op.greater(x, relax.const(alpha)).astype("float32") * x + + +class LeakyRelu(OnnxOpConverter): + """Converts an onnx LeakyRelu node into an equivalent Relax expression. + + f(x) = x for x > 0, alpha * x otherwise + """ + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + x = inputs[0] + alpha = attr.get("alpha", 0.01) + return relax.op.nn.leakyrelu(x, alpha) + + +class Gelu(OnnxOpConverter): + """Operator converter for Gelu from Microsoft onnxruntime contrib opset. + + gelu(x) = 0.5x(1 + erf(x/sqrt(2))) + """ + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.nn.gelu(inputs[0]) + + +class FastGelu(OnnxOpConverter): + """Operator converter for FastGelu from Microsoft onnxruntime contrib opset. + + fast_gelu(x) = 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3))) + = 0.5x(1 + tanh((sqrt(2/pi)x + 0.044715(sqrt(2/pi)x^3))) + = 0.5x(1 + tanh(c1 * x + c2 * x^3))) + , where + c1 = sqrt(2/pi) + c2 = 0.044715 * sqrt(2/pi) + """ + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + if inputs[1]: + bias = inputs[1] + bias_shape = bias.struct_info.shape + assert len(bias_shape) == 1, "bias term must be a 1D tensor" + x += bias + + # Declare consts + const_dtype = x.struct_info.dtype + half = relax.const(0.5, dtype=const_dtype) + one = relax.const(1.0, dtype=const_dtype) + const1 = relax.const(math.sqrt(2 / math.pi), dtype=const_dtype) + const2 = relax.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype) + + # Compute FastGelu + term1 = relax.op.multiply(half, x) + term2 = relax.op.multiply(const1, x) + term3 = relax.op.multiply(const2, relax.op.power(x, relax.const(3, const_dtype))) + tanh = relax.op.tanh(relax.op.add(term2, term3)) + return relax.op.multiply(term1, relax.op.add(one, tanh)) + + +class BiasGelu(OnnxOpConverter): + """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset. + + bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2))) + """ + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + inp = relax.op.add(inputs[0], inputs[1]) + return relax.op.nn.gelu(inp) + + +class Shrink(OnnxOpConverter): + """Converts an onnx Shrink node into an equivalent Relax expression. + + f(x) = x + bias if x > lambd, x - bias if x < -lambd, 0 otherwise + """ + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + x = inputs[0] + dtype = x.struct_info.dtype + lambd = relax.const(attr.get("lambd", 0.5), dtype) + bias = relax.const(attr.get("bias", 0.0), dtype) + zeros = relax.op.zeros_like(x) + return relax.op.where(x > lambd, x - bias, zeros) + relax.op.where( + x < -lambd, x + bias, zeros + ) class Conv(OnnxOpConverter): @@ -730,21 +945,55 @@ def _impl_v11(cls, bb, inputs, attr, params): weight=inputs[1], strides=attr.get("strides", 1), padding=attr.get("pads", 0), - dilation=attr.get("dilation", 1), + dilation=attr.get("dilations", 1), groups=attr.get("group", 1), data_layout=data_layout, kernel_layout=kernel_layout, ) ) if inputs[2] is not None: - bias = relax.op.reshape( - inputs[2], - [1, -1] - + [ - 1, - ] - * (ndim - 2), - ) + bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2)) + conv_out = relax.op.add(conv_out, bias) + + return conv_out + + +class ConvTranspose(OnnxOpConverter): + """Converts an onnx ConvTranspose node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + if hasattr(inputs[0].struct_info, "ndim"): + ndim = inputs[0].struct_info.ndim + else: + ndim = len(inputs[0].struct_info.shape) + + if ndim == 3: + op = relax.op.nn.conv1d_transpose + data_layout = "NCW" + kernel_layout = "IOW" + elif ndim == 4: + op = relax.op.nn.conv2d_transpose + data_layout = "NCHW" + kernel_layout = "IOHW" + elif ndim == 5: + raise NotImplementedError("Relax ConvTranspose3d not supported yet") + else: + raise NotImplementedError("Ndim > 5 not supported for convolution.") + + conv_out = op( + data=inputs[0], + weight=inputs[1], + strides=attr.get("strides", 1), + padding=attr.get("pads", 0), + dilation=attr.get("dilations", 1), + groups=attr.get("group", 1), + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + + if inputs[2] is not None: + bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2)) conv_out = relax.op.add(conv_out, bias) return conv_out @@ -839,17 +1088,6 @@ def _impl_v9(cls, bb, inputs, attr, params): return relax.op.broadcast_to(relax.const(value, dtype), shape) -class Sub(OnnxOpConverter): - """Converts an onnx Sub node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() - inputs[1].data.numpy() - return relax.const(output, output.dtype) - return relax.op.subtract(inputs[0], inputs[1]) - - class Sin(OnnxOpConverter): """Converts an onnx Sin node into an equivalent Relax expression.""" @@ -858,6 +1096,14 @@ def _impl_v7(cls, bb, inputs, attr, params): return relax.op.sin(inputs[0]) +class Sinh(OnnxOpConverter): + """Converts an onnx Sinh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.sinh(inputs[0]) + + class Cos(OnnxOpConverter): """Converts an onnx Cos node into an equivalent Relax expression.""" @@ -866,6 +1112,78 @@ def _impl_v7(cls, bb, inputs, attr, params): return relax.op.cos(inputs[0]) +class Cosh(OnnxOpConverter): + """Converts an onnx Cosh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.cosh(inputs[0]) + + +class Tan(OnnxOpConverter): + """Converts an onnx Tan node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.tan(inputs[0]) + + +class Tanh(OnnxOpConverter): + """Converts an onnx Tanh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.tanh(inputs[0]) + + +class Acos(OnnxOpConverter): + """Converts an onnx Acos node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.acos(inputs[0]) + + +class Acosh(OnnxOpConverter): + """Converts an onnx Acosh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.acosh(inputs[0]) + + +class Asin(OnnxOpConverter): + """Converts an onnx Asin node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.asin(inputs[0]) + + +class Asinh(OnnxOpConverter): + """Converts an onnx Asinh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.asinh(inputs[0]) + + +class Atan(OnnxOpConverter): + """Converts an onnx Atan node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.atan(inputs[0]) + + +class Atanh(OnnxOpConverter): + """Converts an onnx Atanh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.atanh(inputs[0]) + + class Neg(OnnxOpConverter): """Converts an onnx Neg node into an equivalent Relax expression.""" @@ -877,47 +1195,121 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.negative(inputs[0]) -class Abs(OnnxOpConverter): - """Converts an onnx Abs node into an equivalent Relax expression.""" +class Abs(OnnxOpConverter): + """Converts an onnx Abs node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + if isinstance(inputs[0], relax.Constant): + output = _np.abs(inputs[0].data.numpy()) + return relax.const(output, output.dtype) + return relax.op.abs(inputs[0]) + + +class Reciprocal(OnnxOpConverter): + """Converts an onnx Reciprocal node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + input_dtype = inputs[0].struct_info.dtype + return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0]) + + +class Floor(OnnxOpConverter): + """Converts an onnx Floor node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.floor(inputs[0]) + + +class Ceil(OnnxOpConverter): + """Converts an onnx Ceil node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.ceil(inputs[0]) + + +class Round(OnnxOpConverter): + """Converts an onnx Round node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.round(inputs[0]) + + +class IsInf(OnnxOpConverter): + """Converts an onnx IsInf node into an equivalent Relax expression.""" + + @classmethod + def _impl_v10(cls, bb, inputs, attr, params): + return relax.op.isinf(inputs[0]) + + +class IsNaN(OnnxOpConverter): + """Converts an onnx IsNaN node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if isinstance(inputs[0], relax.Constant): - output = _np.abs(inputs[0].data.numpy()) - return relax.const(output, output.dtype) - return relax.op.abs(inputs[0]) + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.isnan(inputs[0]) -class Min(OnnxOpConverter): - """Converts an onnx Min node into an equivalent Relax expression.""" +class Sqrt(OnnxOpConverter): + """Converts an onnx Sqrt node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.sqrt(inputs[0]) + + +class MultiInputBase(OnnxOpConverter): + """Converts an onnx MultiInputBase node into an equivalent Relax expression.""" + + numpy_op: Callable = None + relax_op: Callable = None + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + if cls.numpy_op is None or cls.relax_op is None: + raise NotImplementedError("numpy_op and relax_op must be defined for MultiInputBase") if all([isinstance(inp, relax.Constant) for inp in inputs]): np_inputs = [inp.data.numpy() for inp in inputs] - output = _np.minimum(*np_inputs) + output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable return relax.const(output, output.dtype) # Expand inputs, stack them, then perform minimum over the new axis. inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] stacked_tensor = relax.op.concat(inputs, axis=0) - return relax.op.min(stacked_tensor, axis=0) + return cls.relax_op(stacked_tensor, axis=0) # pylint: disable=not-callable + + +class Min(MultiInputBase): + """Converts an onnx Min node into an equivalent Relax expression.""" + + numpy_op = _np.min + relax_op = relax.op.min -class Max(OnnxOpConverter): +class Max(MultiInputBase): """Converts an onnx Max node into an equivalent Relax expression.""" - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - np_inputs = [inp.data.numpy() for inp in inputs] - output = _np.maximum(*np_inputs) - return relax.const(output, output.dtype) + numpy_op = _np.max + relax_op = relax.op.max - # Expand inputs, stack them, then perform maximum over the new axis. - inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] - stacked_tensor = relax.op.concat(inputs, axis=0) - return relax.op.max(stacked_tensor, axis=0) + +class Mean(MultiInputBase): + """Converts an onnx Mean node into an equivalent Relax expression.""" + + numpy_op = _np.mean + relax_op = relax.op.mean + + +class Sum(MultiInputBase): + """Converts an onnx Sum node into an equivalent Relax expression.""" + + numpy_op = _np.sum + relax_op = relax.op.sum class Log(OnnxOpConverter): @@ -956,26 +1348,22 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.exp(data) -class Less(OnnxOpConverter): - """Converts an onnx Less node into an equivalent Relax expression.""" +class Softplus(OnnxOpConverter): + """Converts an onnx Softplus node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = _np.less(inputs[0].data.numpy(), inputs[1].data.numpy()) - return relax.const(output, output.dtype) - return relax.op.less(inputs[0], inputs[1]) + def _impl_v1(cls, bb, inputs, attr, params): + dtype = inputs[0].struct_info.dtype + return relax.op.log(relax.op.exp(inputs[0]) + relax.const(1, dtype=dtype)) -class LessOrEqual(OnnxOpConverter): - """Converts an onnx LessOrEqual node into an equivalent Relax expression.""" +class Softsign(OnnxOpConverter): + """Converts an onnx Softsign node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = _np.less_equal(inputs[0].data.numpy(), inputs[1].data.numpy()) - return relax.const(output, output.dtype) - return relax.op.less_equal(inputs[0], inputs[1]) + def _impl_v1(cls, bb, inputs, attr, params): + dtype = inputs[0].struct_info.dtype + return inputs[0] / (relax.op.abs(inputs[0]) + relax.const(1, dtype=dtype)) class Split(OnnxOpConverter): @@ -1456,6 +1844,20 @@ def _impl_v15(cls, bb, inputs, attr, params): ) +class MeanVarianceNormalization(OnnxOpConverter): + """Converts an onnx MeanVarianceNormalization node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + data = inputs[0] + axis = attr.get("axes", (0, 2, 3)) + data_mean = relax.op.mean(data, axis=axis, keepdims=True) + data_mean_squared = relax.op.power(data_mean, relax.const(2, dtype="float32")) + data_squared = relax.op.power(data, relax.const(2, dtype="float32")) + data_squared_mean = relax.op.mean(data_squared, axis=axis, keepdims=True) + return (data - data_mean) / relax.op.sqrt(data_squared_mean - data_mean_squared) + + class Pool(OnnxOpConverter): """A helper class for pool op converters.""" @@ -1557,16 +1959,79 @@ class GlobalAveragePool(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): rank = len(inputs[0].struct_info.shape) - if rank == 3: - return relax.op.nn.adaptive_avg_pool1d(inputs[0], 1) - elif rank == 4: - return relax.op.nn.adaptive_avg_pool2d(inputs[0], 1) - elif rank == 5: - return relax.op.nn.adaptive_avg_pool3d(inputs[0], 1) - raise NotImplementedError( - "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD." - % (rank - 2) + axes = list(range(2, rank)) + return relax.op.mean(inputs[0], axis=axes, keepdims=True) + + +class GlobalMaxPool(OnnxOpConverter): + """Converts an onnx GlobalMaxPool node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + rank = len(inputs[0].struct_info.shape) + axes = list(range(2, rank)) + return relax.op.max(inputs[0], axis=axes, keepdims=True) + + +class GlobalLpPool(OnnxOpConverter): + """Converts an onnx GlobalLpPool node into an equivalent Relax expression.""" + + @classmethod + def _impl_v2(cls, bb, inputs, attr, params): + p = attr.get("p", 2.0) + dtype = inputs[0].struct_info.dtype + rank = len(inputs[0].struct_info.shape) + axes = list(range(2, rank)) + x_abs = relax.op.abs(inputs[0]) + x_p = relax.op.power(x_abs, relax.const(p, dtype=dtype)) + x_sum = relax.op.sum(x_p, axes, keepdims=True) + return relax.op.power(x_sum, relax.const(1.0 / p, dtype=dtype)) + + +class MaxUnpool(OnnxOpConverter): + """Converts an onnx MaxUnpool node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + data = inputs[0] + indices = inputs[1] + output_shape = inputs[2] + kernel_shape = attr.get("kernel_shape") + pads = attr.get("pads", [0] * len(kernel_shape) * 2) + strides = attr.get("strides", [1] * len(kernel_shape)) + + multiplier = _np.concatenate([[1, 1], list(strides)]) + shape = [v.value for v in data.struct_info.shape] + total_output_shape = multiplier * shape + # Add extra dimensions from kernel size and stride mismatch + total_output_shape += _np.concatenate([[0, 0], list(kernel_shape)], axis=0) + total_output_shape -= _np.concatenate([[0, 0], list(strides)], axis=0) + + if output_shape is not None: + total_output_shape = output_shape + + elif pads is not None: + # Get pads in the proper format for relay. + pads = _np.concatenate([[0, 0, 0, 0], list(pads)], axis=0) + pads = _np.reshape(pads, [-1, 2]) + # Compute the total padding per axis. + total_pad = _np.sum(pads, axis=-1) + # Reversing maxpool means that padding actually makes our output smaller. + total_output_shape = total_output_shape - total_pad + + # Create a tensor of zeros then scatter our data through it. + relax_shape = relax.ShapeExpr(total_output_shape.tolist()) + zeros_tensor = bb.emit(relax.op.zeros(relax_shape, data.struct_info.dtype)) + # We need to flatten all our tensors before scattering. + flat_tensor = relax.op.scatter_elements( + relax.op.reshape(zeros_tensor, [-1]), + relax.op.reshape(indices, [-1]), + relax.op.reshape(data, [-1]), + axis=0, ) + # Reshape our flattened data back to normal. + output = relax.op.reshape(flat_tensor, relax_shape) + return output class Flatten(OnnxOpConverter): @@ -1799,6 +2264,32 @@ def _impl_v12(cls, bb, inputs, attr, params): return relax.op.argmin(data, axis, keepdims) +class TopK(OnnxOpConverter): + """Converts an onnx TopK node into an equivalent Relax expression.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + data = inputs[0] + k = inputs[1] + if not isinstance(k, relax.Constant): + raise ValueError("TopK k must be a constant") + k = int(k.data.numpy()) + axis = attr.get("axis", -1) + largest = attr.get("largest", 1) + sorted = attr.get("sorted", 1) + if sorted != 1: + raise ValueError("TopK sorted must be 1 for Relax frontend") + + return relax.op.topk(data, k, axis, ret_type="both", largest=largest) + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + data = inputs[0] + k = attr.get("k", 1) + axis = attr.get("axis", -1) + return relax.op.topk(data, k, axis, ret_type="both") + + class SkipLayerNormalization(OnnxOpConverter): """Converts a microsoft contrib SkipLayerNormalization node into a Relax expression.""" @@ -1871,26 +2362,6 @@ def _impl_v1(cls, bb, inputs, attr, params): return relax.Tuple([ln, mask_index]) -class Greater(OnnxOpConverter): - """Converts an onnx Greater node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = _np.greater(inputs[0].data.numpy(), inputs[1].data.numpy()) - return relax.const(output, output.dtype) - return relax.op.greater(inputs[0], inputs[1]) - - -class Reciprocal(OnnxOpConverter): - """Converts an onnx Reciprocal node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - input_dtype = inputs[0].struct_info.dtype - return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0]) - - class OneHot(OnnxOpConverter): """Converts an onnx OneHot node into an equivalent Relax expression.""" @@ -1909,15 +2380,16 @@ def _impl_v11(cls, bb, inputs, attr, params): return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype) -class Elu(OnnxOpConverter): - """Converts an onnx Elu node into an equivalent Relax expression.""" +class Unique(OnnxOpConverter): + """Converts an onnx Unique node into an equivalent Relax expression.""" @classmethod - def _impl_v1(cls, bb, inputs, attr, params): - alpha = float(attr.get("alpha", 1.0)) - return relax.expr.const(-alpha) * relax.op.nn.relu( - relax.expr.const(1.0) - relax.op.exp(inputs[0]) - ) + relax.op.nn.relu(inputs[0]) + def _impl_v11(cls, bb, inputs, attr, params): + data = inputs[0] + axis = attr.get("axis", None) + sorted = bool(attr.get("sorted", 1)) + # TODO(tvm-team): Add support for return_index, return_inverse, return_counts + return relax.op.unique(data, sorted=sorted, axis=axis) class HardSigmoid(OnnxOpConverter): @@ -1966,53 +2438,308 @@ def _impl_v1(cls, bb, inputs, attr, params): return relax.op.logical_not(inputs[0]) +class DepthToSpace(OnnxOpConverter): + """Converts an onnx DepthToSpace node into an equivalent Relax expression.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + block_size = int(attr["blocksize"]) + mode = attr.get("mode", b"DCR").decode("utf-8") + b, c, h, w = inputs[0].struct_info.shape + if mode == "DCR": + x = relax.op.reshape( + inputs[0], (b, block_size, block_size, c // (block_size**2), h, w) + ) + x = relax.op.permute_dims(x, [0, 3, 4, 1, 5, 2]) + return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) + elif mode == "CRD": + x = relax.op.reshape( + inputs[0], (b, c // (block_size**2), block_size, block_size, h, w) + ) + x = relax.op.permute_dims(x, [0, 1, 4, 2, 5, 3]) + return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) + else: + raise ValueError(f"Unsupported mode: {mode}, expected DCR or CRD") + + +class SpaceToDepth(OnnxOpConverter): + """Converts an onnx SpaceToDepth node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + block_size = int(attr["blocksize"]) + b, c, h, w = inputs[0].struct_info.shape + x = relax.op.reshape( + inputs[0], (b, c, h // block_size, block_size, w // block_size, block_size) + ) + x = relax.op.permute_dims(x, [0, 3, 5, 1, 2, 4]) + return relax.op.reshape( + x, (b, c * block_size * block_size, h // block_size, w // block_size) + ) + + +class SequenceConstruct(OnnxOpConverter): + """Operator converter for sequence construction op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Construct a tuple from input tensors. + return relax.Tuple(inputs) + + +class SequenceEmpty(OnnxOpConverter): + """Operator converter for sequence empty op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Construct an empty tuple. + return relax.Tuple([]) + + +class SequenceErase(OnnxOpConverter): + """Operator converter for sequence erase op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Erase tensor from sequence on specified position + input_sequence = inputs[0] + + if len(inputs) == 2: + position = inputs[1] + # Non constant position is not supported. + if isinstance(position, relax.Constant): + position = int(position.data.numpy()) + else: + raise NotImplementedError("Position must be a constant.") + else: + position = -1 + + seq_len = len(input_sequence) + if not -seq_len <= position < seq_len: + raise ValueError( + f"Position is out of bounds, expected [-{seq_len}, {seq_len}), got {position}" + ) + + if position < 0: + position = seq_len + position + # Convert sequence to a list, insert tensors before erased, and repackage as Tuple. + tensor_list = [input_sequence[i] for i in range(seq_len) if i != position] + # Create new tuple and return. + return relax.Tuple(tensor_list) + + +class SequenceInsert(OnnxOpConverter): + """Operator converter for sequence insert op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Insert a new tensor into a tuple of tensors. + input_sequence = inputs[0] + tensor_to_insert = inputs[1] + + if len(inputs) == 3: + position = inputs[2] + # Non constant position is not supported. + if isinstance(position, relax.Constant): + position = position.data.numpy() + else: + raise NotImplementedError("Position must be a constant.") + else: + position = -1 + + if position < 0: + position = len(input_sequence) + position + 1 + # Convert sequence to a list, insert new tensor, and repackage as Tuple. + tensor_list = [input_sequence[i] for i in range(len(input_sequence))] + # Insert new tensor. + tensor_list.insert(position, tensor_to_insert) + # Create new tuple and return. + return relax.Tuple(tensor_list) + + +class SequenceLength(OnnxOpConverter): + """Operator converter for sequence length op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Get length of input sequence + return relax.const(len(inputs[0]), dtype="int64") + + +class ConcatFromSequence(OnnxOpConverter): + """Operator converter for sequence concatenation op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + new_axis = attr.get("new_axis", 0) + + if new_axis == 1: + raise NotImplementedError("Insert new axis is not supported yet.") + + return relax.op.concat(inputs[0], axis=axis) + + +class SplitToSequence(OnnxOpConverter): + """Operator converter for split to sequence op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", 1) + + input_tensor = inputs[0] + input_shape = input_tensor.struct_info.shape + + # If split is not provided, we split all values along axis. + if len(inputs) == 1: + split = _np.array(1) + if not keepdims: + raise NotImplementedError("Only keepdims=1 is supported for now") + else: + split = inputs[1] + if not isinstance(split, relax.Constant): + raise ValueError("Only constant split supported for SplitToSequence") + split = split.data.numpy() + + if len(split.shape) == 1 and split.shape[0] > 1: + split = _np.cumsum(split) + split = list(split[:-1]) + else: + chunk_size, dim_size = int(split), input_shape[axis] + if dim_size % chunk_size != 0: + raise ValueError( + f"Dimension of size {dim_size} along axis {axis} must be " + f"evenly divisible by chunk size {chunk_size}" + ) + split = dim_size // chunk_size + + output = relax.op.split(input_tensor, split, axis=axis) + return output + + +class SequenceAt(OnnxOpConverter): + """Operator converter for sequence at op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + input_sequence = inputs[0] + position = inputs[1] + assert isinstance( + position, relax.Constant + ), "Only constant position supported for SequenceAt" + position = int(position.data.numpy()) + return input_sequence[position] + + def _get_convert_map(): return { - "MatMul": MatMul, - "Concat": Concat, + # defs/experimental + # "Optional": Optional_, + # "OptionalHasElement": OptionalHasElement, + # "OptionalGetElement": OptionalGetElement, + # Binary operators "Add": Add, + "Sub": Sub, "Mul": Mul, - "Cast": Cast, + "Div": Div, + # "Mod": Mod, + "Less": Less, + "LessOrEqual": LessOrEqual, + "Greater": Greater, + "GreaterOrEqual": GreaterOrEqual, + "Equal": Equal, + "BitwiseAnd": BitwiseAnd, + "BitwiseOr": BitwiseOr, + "BitwiseXor": BitwiseXor, + # "BitwiseNot": BitwiseNot, + # "BitwiseShift": BitwiseShift, + "And": And, + "Or": Or, + "Xor": Xor, + "Not": Not, + # Unary operators + "Log": Log, + "Exp": Exp, + "Acos": Acos, + "Acosh": Acosh, + "Asin": Asin, + "Asinh": Asinh, + "Atan": Atan, + "Atanh": Atanh, + "Cos": Cos, + "Cosh": Cosh, + "Sin": Sin, + "Sinh": Sinh, + "Tan": Tan, + "Tanh": Tanh, + "Neg": Neg, + "Abs": Abs, + "Reciprocal": Reciprocal, + "Floor": Floor, + "Ceil": Ceil, + "Round": Round, + "IsInf": IsInf, + "IsNaN": IsNaN, + "Sqrt": Sqrt, + "Relu": Relu, + "Selu": Selu, + "Mish": Mish, + "Trilu": Trilu, + "PRelu": PRelu, + "LeakyRelu": LeakyRelu, + "ThresholdedRelu": ThresholdedRelu, + "Elu": Elu, + "Gelu": Gelu, + "FastGelu": FastGelu, + "BiasGelu": BiasGelu, + "HardSigmoid": HardSigmoid, + "HardSwish": HardSwish, + "Sign": Sign, + "Softplus": Softplus, + "Softsign": Softsign, + "Shrink": Shrink, + "Erf": Erf, "Sum": Sum, - "Gather": Gather, + "Min": Min, + "Max": Max, + "Mean": Mean, + "Cast": Cast, "Gemm": Gemm, + "MatMul": MatMul, + # "MatMulInteger": MatMulInteger, + # "MatMulInteger16": MatMulInteger16, "Reshape": Reshape, - "Div": Div, "Sigmoid": Sigmoid, "Softmax": Softmax, + "LogSoftmax": LogSoftmax, + # "Hardmax": Hardmax, "Transpose": Transpose, "Unsqueeze": Unsqueeze, - "Gelu": Gelu, - "BiasGelu": BiasGelu, "Where": Where, + "Concat": Concat, "Clip": Clip, - "Equal": Equal, "Shape": Shape, - "Tanh": Tanh, - "Sqrt": Sqrt, - "Trilu": Trilu, - "Relu": Relu, - "Conv": Conv, "Pow": Pow, - "Erf": Erf, "CumSum": CumSum, "Squeeze": Squeeze, "Constant": Constant, - "Sub": Sub, - "Sin": Sin, - "Cos": Cos, - "Neg": Neg, - "Abs": Abs, - "Min": Min, - "Max": Max, - "Log": Log, - "Exp": Exp, - "Less": Less, - "LessOrEqual": LessOrEqual, + "Gather": Gather, + # "GatherElements": GatherElements, + # "GatherND": GatherND, + "Scatter": Scatter, + "ScatterElements": ScatterElements, + # "ScatterND": ScatterND, + # "Compress": Compress, + "Size": Size, + # "EyeLike": EyeLike, + # Normalization + "BatchNormalization": BatchNormalization, "LayerNormalization": LayerNormalization, "SkipLayerNormalization": SkipLayerNormalization, "EmbedLayerNormalization": EmbedLayerNormalization, "InstanceNormalization": InstanceNormalization, + "MeanVarianceNormalization": MeanVarianceNormalization, # defs/reduction "ReduceMax": ReduceMax, "ReduceMin": ReduceMin, @@ -2026,6 +2753,7 @@ def _get_convert_map(): "ReduceL2": ReduceL2, "ArgMax": ArgMax, "ArgMin": ArgMin, + "TopK": TopK, "Expand": Expand, "ConstantOfShape": ConstantOfShape, "Slice": Slice, @@ -2033,23 +2761,42 @@ def _get_convert_map(): "Pad": Pad, "Split": Split, "Tile": Tile, - "BatchNormalization": BatchNormalization, - "MaxPool": MaxPool, "AveragePool": AveragePool, + "MaxPool": MaxPool, + # "LpPool": LpPool, "GlobalAveragePool": GlobalAveragePool, + "GlobalMaxPool": GlobalMaxPool, + "GlobalLpPool": GlobalLpPool, + "MaxUnpool": MaxUnpool, + "Conv": Conv, + "ConvTranspose": ConvTranspose, "Flatten": Flatten, "Identity": Identity, "Resize": Resize, "Einsum": Einsum, "Range": Range, - "Greater": Greater, - "Reciprocal": Reciprocal, "OneHot": OneHot, - "Elu": Elu, - "HardSigmoid": HardSigmoid, - "HardSwish": HardSwish, - "Sign": Sign, - "Not": Not, + "Unique": Unique, + # "NonZero": NonZero, + # "If": If, + # "LRN": LRN, + # "MaxRoiPool": MaxRoiPool, + # "RoiAlign": RoiAlign, + # "NonMaxSuppression": NonMaxSuppression, + # "GridSample": GridSample, + # "Upsample": Upsample, + # others + "DepthToSpace": DepthToSpace, + "SpaceToDepth": SpaceToDepth, + # Sequence operators + "SequenceConstruct": SequenceConstruct, + "SequenceEmpty": SequenceEmpty, + "SequenceErase": SequenceErase, + "SequenceInsert": SequenceInsert, + "SequenceLength": SequenceLength, + "ConcatFromSequence": ConcatFromSequence, + "SplitToSequence": SplitToSequence, + "SequenceAt": SequenceAt, } @@ -2269,6 +3016,14 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): "Where", "Cast", ] + return_tuple_ops = [ + "SequenceConstruct", + "SequenceEmpty", + "SequenceErase", + "SequenceInsert", + "ConcatFromSequence", + "SplitToSequence", + ] for i, inp in enumerate(inputs): if ( inp is not None @@ -2277,11 +3032,17 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): and op_name not in shape_compatible_ops ): raise ValueError(f"Node {node.name} cannot handle ShapeExpr inputs.") - op = self._convert_operator(op_name, inputs, attr, self.opset) - # Create struct information for the new operator. - op = self.bb.normalize(op) - - if not isinstance(op, relax.Tuple): + try: + op = self._convert_operator(op_name, inputs, attr, self.opset) + # Create struct information for the new operator. + op = self.bb.normalize(op) + except TVMError as err: + print(f"Error converting operator {op_name}, with inputs: {inputs}") + raise err + + if op_name in return_tuple_ops: + outputs_num = 1 + elif not isinstance(op, relax.Tuple): if isinstance(op.checked_type, tvm.ir.type.TupleType): # This is a var bound to a tuple. We need to unpack it and create # a new tuple. @@ -2299,7 +3060,6 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): ), "Missing outputs during conversion. Expected {} but Got {} in {}.".format( len(outputs), outputs_num, op_name ) - if outputs_num == 1: self._nodes[outputs[0]] = op else: @@ -2346,10 +3106,10 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> Dict[str, def _convert_operator( self, op_name: str, - inputs: List[relax.Function], + inputs: List[relax.Expr], attrs: Dict, opset: int, - ) -> relax.Function: + ) -> relax.Expr: """Convert ONNX operator into a Relax operator. The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. @@ -2386,7 +3146,7 @@ def from_onnx( opset: int = None, keep_params_in_input: bool = False, sanitize_input_names: bool = True, -) -> Tuple[IRModule, Dict]: +) -> IRModule: """Convert a ONNX model into an equivalent Relax Function. ONNX graphs are represented as Python Protobuf objects. @@ -2413,8 +3173,6 @@ def from_onnx( ------- mod : tvm.IRModule The relax module for compilation - params : dict of str to tvm.nd.NDArray - The parameter dict to be used by relax """ # Error if the model version is below 1.1.0 if model.ir_version < 3: diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index 4d106ad6d23c..0b86e19ce53f 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -77,7 +77,7 @@ def unique( return_inverse = PrimValue(return_inverse) if isinstance(return_counts, bool): return_counts = PrimValue(return_counts) - if axis and isinstance(axis, int): + if axis is not None and isinstance(axis, int): axis = PrimValue(axis) return _ffi_api.unique( # type: ignore x, sorted, return_index, return_inverse, return_counts, axis @@ -91,6 +91,7 @@ def numpy_unique( return_index: int, return_inverse: int, return_counts: int, + axis: Optional[int] = None, ) -> tvm.nd.array: """Returns the unique elements of the input tensor. @@ -103,8 +104,9 @@ def numpy_unique( raise NotImplementedError("missing support return_inverse or return_counts set to true") x_numpy = x.numpy() # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci. - output_sorted_numpy, indices = np.unique(x_numpy, return_index=True) + output_sorted_numpy, indices = np.unique(x_numpy, return_index=True, axis=axis) + if sorted: return tvm.nd.array(output_sorted_numpy) - output_numpy = [x_numpy.flatten()[index] for index in builtins.sorted(indices, reverse=True)] + output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis) return tvm.nd.array(output_numpy) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 809d231fd30d..8317d4504e1e 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -171,21 +171,16 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr: "and thus cannot be legalized by TOPI" ) return call - if call.attrs.groups != 1: - logging.info( - "TOPI conv1d_transpose does not support groups other than 1, " - "and thus cannot be legalized by TOPI" - ) - return call return bb.call_te( - topi.nn.conv1d_transpose_ncw, + topi.nn.group_conv1d_transpose_ncw, call.args[0], call.args[1], stride=call.attrs.strides, padding=call.attrs.padding, out_dtype=call.struct_info.dtype, output_padding=call.attrs.output_padding, + groups=call.attrs.groups, primfunc_name_hint="conv1d_transpose", ) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 0e7cfbd7c093..2837ad2185e9 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -21,7 +21,7 @@ This file is a test script to test Relax ONNX frontend coverage. """ -from typing import Dict, Optional +from typing import Dict, List, Literal, Optional import numpy as np import onnx @@ -118,6 +118,7 @@ def check_correctness( tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) # Legalize any relax ops into tensorir. tvm_model = relax.transform.LegalizeOps()(tvm_model) + print(tvm_model) # Separate model from parameters. tvm_model, params = relax.frontend.detach_params(tvm_model) @@ -137,25 +138,31 @@ def check_correctness( vm.invoke_stateful("main") tvm_output = vm.get_outputs("main") # Wrap as a list if there is only one output. - if isinstance(tvm_output, tvm.nd.NDArray): + if len(ort_output) == 1: + # Do not check the output number for TVM + # As for sequence output, the TVM output is a Tuple + # while the ONNX output number is one, which is a list tvm_output = [tvm_output] - # If the output is a shape tuple, convert it to an ndarray for comparison. - if isinstance(tvm_output, tvm.runtime.ShapeTuple): - tvm_output = [tvm.nd.array([int(i) for i in tvm_output])] - tvm_num_outputs = len(tvm_output) - # Shape tuples need to be handled specially. - if isinstance(tvm_output, tvm.runtime.ShapeTuple): - tvm_num_outputs = 1 + def _check_output(tvm_out, ort_out): + if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)): + assert len(tvm_out) == len(ort_out), "Unequal number of outputs" + for tvm_out_i, ort_out_i in zip(tvm_out, ort_out): + _check_output(tvm_out_i, ort_out_i) + elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray): + tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol) + elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray): + shape_out = tvm.nd.array([int(i) for i in tvm_out]) + tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol) + else: + raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}") # Check that number of outputs match. - assert tvm_num_outputs == len(ort_output), "Unequal number of outputs" - + assert len(tvm_output) == len(ort_output), "Unequal number of outputs" for (tvm_out, ort_out) in zip(tvm_output, ort_output): # TODO Allow configurable tolerance. - # Sometimes None is used to indicate an unused output. if ort_out is not None: - tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol) + _check_output(tvm_out, ort_out) @pytest.mark.parametrize( @@ -187,35 +194,61 @@ def test_sanitize(input_names, expected_names): assert param.name_hint == expected_names[i] -def verify_unary(op_name, shape, attrs={}, domain=None, dtype=TensorProto.FLOAT): +def verify_unary( + op_name, + shape, + attrs={}, + domain=None, + input_dtype=TensorProto.FLOAT, + output_dtype=TensorProto.FLOAT, + opset=14, +): test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain) graph = helper.make_graph( [test_node], "elemwise_test", inputs=[ - helper.make_tensor_value_info("x", dtype, shape), + helper.make_tensor_value_info("x", input_dtype, shape), ], - outputs=[helper.make_tensor_value_info("y", dtype, shape)], + outputs=[helper.make_tensor_value_info("y", output_dtype, shape)], ) model = helper.make_model(graph, producer_name="elemwise_test") - check_correctness(model) + check_correctness(model, opset=opset) -def verify_binary(op_name, shape_a, shape_b, shape_c, attrs={}, domain=None): +def verify_binary( + op_name, shape_a, shape_b, shape_c, attrs={}, domain=None, dtype=TensorProto.FLOAT, opset=14 +): test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain) graph = helper.make_graph( [test_node], "binary_test", inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a), - helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b), + helper.make_tensor_value_info("a", dtype, shape_a), + helper.make_tensor_value_info("b", dtype, shape_b), ], - outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, shape_c)], + outputs=[helper.make_tensor_value_info("c", dtype, shape_c)], ) model = helper.make_model(graph, producer_name="binary_test") - check_correctness(model) + check_correctness(model, opset=opset) + + +def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32, opset=14): + a = make_constant_node("a", dtype, [], [4]) + b = make_constant_node("b", dtype, [], [8]) + test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain) + graph = helper.make_graph( + [a, b, test_node], + "binary_test", + inputs=[], + outputs=[helper.make_tensor_value_info("c", dtype, ())], + ) + + model = helper.make_model(graph, producer_name="binary_test") + # NOTE: explicitly pass inputs to avoid numerical error + check_correctness(model, opset=opset) def verify_compare(op_name, shape, attrs={}, domain=None): @@ -289,16 +322,95 @@ def test_concat(): verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0}) -def test_add(): - verify_binary("Add", [1, 32], [1, 32], [1, 32]) +@pytest.mark.parametrize("op_name", ["Add", "Sub", "Mul", "Div", "Pow"]) +def test_binary(op_name: str): + verify_binary(op_name, [1, 32], [1, 32], [1, 32]) + verify_binary_scalar(op_name) + + +@pytest.mark.parametrize("num_inputs", [1, 2, 4]) +@pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"]) +def test_multi_input(op_name: str, num_inputs: int): + input_shape = [32, 32] + input_var = ["i" + str(i) for i in range(num_inputs)] + input_values = [ + helper.make_tensor_value_info(var, TensorProto.FLOAT, input_shape) for var in input_var + ] + test_node = helper.make_node(op_name, input_var, ["c"]) + graph = helper.make_graph( + [test_node], + "multi_input_test", + inputs=input_values, + outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, input_shape)], + ) + + model = helper.make_model(graph, producer_name="multi_input_test") + check_correctness(model) -def test_mul(): - verify_binary("Mul", [1, 32], [1, 32], [1, 32]) +@pytest.mark.parametrize("op_name", ["Less", "LessOrEqual", "Greater", "GreaterOrEqual"]) +def test_compare(op_name: str): + verify_compare(op_name, [1, 32]) -def test_sum(): - verify_binary("Sum", [1, 32], [1, 32], [1, 32]) +@pytest.mark.parametrize("op_name", ["And", "Or", "Xor"]) +def test_binary_bool(op_name: str): + verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.BOOL) + + +@pytest.mark.parametrize( + "op_name", + [ + "Sin", + "Cos", + "Tan", + "Sinh", + "Cosh", + "Tanh", + "Asin", + "Acos", + "Atan", + "Asinh", + "Acosh", + "Atanh", + "Neg", + "Abs", + "Log", + "Exp", + "Not", + "Reciprocal", + "Floor", + "Ceil", + "Round", + "IsInf", + "IsNaN", + "Sqrt", + "Relu", + "Elu", + "HardSwish", + "Sign", + "Softplus", + "Softsign", + "Erf", + "Sigmoid", + "Softmax", + "LogSoftmax", + "Identity", + ], +) +def test_unary(op_name: str): + input_dtype = TensorProto.FLOAT + if op_name in [ + "IsNaN", + "IsInf", + ]: + pytest.skip(f"Skipping test {op_name} because current LegalizeOps does not support it.") + elif op_name == "Not": + input_dtype = TensorProto.BOOL + output_dtype = TensorProto.BOOL + else: + output_dtype = TensorProto.FLOAT + verify_unary(op_name, [32, 32], input_dtype=input_dtype, output_dtype=output_dtype) @pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16]) @@ -350,6 +462,44 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1) +@pytest.mark.parametrize("axis", [0, 1, 2]) +@pytest.mark.parametrize(("name", "opset"), [("Scatter", 10), ("ScatterElements", 11)]) +def test_scatter(axis: int, name: str, opset: int): + if axis != 1: + pytest.skip("The current topi impl is wrong, which only works for axis=1") + input_shape = [16, 16, 16] + indices_shape = [8, 8, 8] + updates_shape = [8, 8, 8] + output_shape = [16, 16, 16] + node = helper.make_node(name, ["data", "indices", "updates"], ["output"], axis=axis) + graph = helper.make_graph( + [node], + "scatter_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="scatter_test") + indices = np.random.randint(0, 16, indices_shape) + check_correctness(model, inputs={"indices": indices}, opset=opset) + + +def test_size(): + test_node = helper.make_node("Size", ["x"], ["y"]) + graph = helper.make_graph( + [test_node], + "size_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 3, 3])], + outputs=[helper.make_tensor_value_info("y", TensorProto.INT64, [3])], + ) + + model = helper.make_model(graph, producer_name="size_test") + check_correctness(model) + + @pytest.mark.parametrize("alpha", [None, 0.25, 1.0]) @pytest.mark.parametrize("beta", [None, 0.35, 1.0]) @pytest.mark.parametrize("useC", [False, True]) @@ -408,18 +558,6 @@ def test_reshape(in_shape, shape, out_shape): check_correctness(model, inputs=input_values) -def test_div(): - verify_binary("Div", [32, 32], [32, 32], [32, 32]) - - -def test_sigmoid(): - verify_unary("Sigmoid", [32, 32]) - - -def test_softmax(): - verify_unary("Softmax", [32, 32, 32]) - - def test_transpose(): verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]}) @@ -567,28 +705,33 @@ def test_shape(): check_correctness(model) -def test_tanh(): - verify_unary("Tanh", [9, 8, 7, 6]) +@pytest.mark.parametrize("upper", [True, False]) +def test_trilu(upper: bool): + verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper}) -def test_sqrt(): - verify_unary("Sqrt", [32, 32]) +def test_selu(): + verify_unary("Selu", [3, 32, 32]) + verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3}) -def test_relu(): - verify_unary("Relu", [32, 32]) +@pytest.mark.skip(reason="opset 18 is not supported in CI") +def test_mish(): + verify_unary("Mish", [3, 32, 32], opset=18) -def test_tril(): - verify_unary("Trilu", [3, 5, 5], attrs={"upper": False}) +def test_prelu(): + verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32]) -def test_triu(): - verify_unary("Trilu", [3, 5, 5], attrs={"upper": True}) +def test_thresholded_relu(): + verify_unary("ThresholdedRelu", [3, 32, 32]) + verify_unary("ThresholdedRelu", [3, 32, 32], attrs={"alpha": -0.01}) -def test_elu(): - verify_unary("Elu", [32, 32]) +def test_leakyrelu(): + verify_unary("LeakyRelu", [32, 32]) + verify_unary("LeakyRelu", [32, 32], attrs={"alpha": 0.2}) def test_hardsigmoid(): @@ -597,30 +740,40 @@ def test_hardsigmoid(): verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6}) -def test_hardswish(): - verify_unary("HardSwish", [32, 32]) - - -def test_sign(): - verify_unary("Sign", [32, 32]) - - -def test_not(): - verify_unary("Not", [32, 32], dtype=TensorProto.BOOL) +def test_shrink(): + verify_unary("Shrink", [32, 32]) + verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1}) -def test_conv(): - def _verify_conv(input_shape, weight_shape, output_shape): +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("dilation", [1, 2]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("pad", [0, 2]) +def test_conv(stride: int, dilation: int, pad: int, bias: bool): + def _verify_conv(input_shape, weight_shape): + nd = len(weight_shape) - 2 + output_shape = [input_shape[0], weight_shape[0]] + [ + (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1 + for i in range(2, len(input_shape)) + ] bias_shape = [output_shape[1]] - conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"]) + conv_node = helper.make_node( + "Conv", + inputs=["x", "w"] + (["b"] if bias else []), + outputs=["y"], + strides=[stride] * nd, + dilations=[dilation] * nd, + pads=[pad] * nd * 2, + group=input_shape[1] // weight_shape[1], + ) graph = helper.make_graph( [conv_node], "conv_test", inputs=[ helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape), helper.make_tensor_value_info("w", TensorProto.FLOAT, weight_shape), - helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape), - ], + ] + + ([helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape)] if bias else []), outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)], ) @@ -628,20 +781,61 @@ def _verify_conv(input_shape, weight_shape, output_shape): check_correctness(model, atol=1e-4) # Conv1D - _verify_conv([3, 12, 32], [4, 12, 3], [3, 4, 30]) + _verify_conv([3, 4, 32], [4, 4, 3]) + _verify_conv([3, 4, 32], [2, 4, 3]) # group=2 # Conv2D - _verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30]) + _verify_conv([3, 4, 32, 32], [4, 4, 3, 3]) + _verify_conv([3, 4, 32, 32], [2, 4, 3, 3]) # group=2 # Conv3D - _verify_conv([3, 12, 32, 32, 32], [4, 12, 3, 3, 3], [3, 4, 30, 30, 30]) + _verify_conv([3, 4, 32, 32, 32], [4, 4, 3, 3, 3]) + _verify_conv([3, 4, 32, 32, 32], [2, 4, 3, 3, 3]) # group=2 + + +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("dilation", [1]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("pad", [0, 2]) +def test_conv_transpose(stride: int, dilation: int, pad: int, bias: bool): + def _verify_conv_transpose(input_shape, weight_shape): + nd = len(weight_shape) - 2 + output_shape = [input_shape[0], weight_shape[0]] + [ + (input_shape[i] - 1) * stride - 2 * pad + dilation * (weight_shape[i] - 1) + 1 + for i in range(2, len(input_shape)) + ] + bias_shape = [output_shape[1]] + conv_node = helper.make_node( + "ConvTranspose", + inputs=["x", "w"] + (["b"] if bias else []), + outputs=["y"], + strides=[stride] * nd, + dilations=[dilation] * nd, + pads=[pad] * nd * 2, + group=input_shape[1] // weight_shape[1], + ) + graph = helper.make_graph( + [conv_node], + "conv_transpose_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("w", TensorProto.FLOAT, weight_shape), + ] + + ([helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape)] if bias else []), + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="conv_transpose_test") + check_correctness(model, atol=1e-4) -def test_pow(): - verify_binary("Pow", [32, 32], [32, 32], [32, 32]) + # ConvTranspose1D + _verify_conv_transpose([3, 4, 32], [4, 4, 3]) + _verify_conv_transpose([3, 4, 32], [4, 2, 3]) # group=2 + # ConvTranspose2D + _verify_conv_transpose([3, 4, 32, 32], [4, 4, 3, 3]) + _verify_conv_transpose([3, 4, 32, 32], [4, 2, 3, 3]) # group=2 -def test_erf(): - verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT) - verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT16) +def test_pow(): + verify_binary("Pow", [32, 32], [32, 32], [32, 32]) @pytest.mark.parametrize("reverse", [False]) @@ -712,46 +906,6 @@ def test_const(): check_correctness(model) -def test_sub(): - verify_binary("Sub", [32, 16], [32, 16], [32, 16]) - - -def test_min(): - verify_binary("Min", [32, 16], [32, 16], [32, 16]) - - -def test_max(): - verify_binary("Max", [32, 16], [32, 16], [32, 16]) - - -def test_sin(): - verify_unary("Sin", [32, 16]) - - -def test_cos(): - verify_unary("Cos", [32, 16]) - - -def test_identity(): - verify_unary("Identity", [32, 16]) - - -def test_neg(): - verify_unary("Neg", [32, 16]) - - -def test_abs(): - verify_unary("Abs", [32, 16]) - - -def test_log(): - verify_unary("Log", [32, 16]) - - -def test_exp(): - verify_unary("Exp", [32, 16]) - - def test_instance_norm(): verify_ternary( "InstanceNormalization", [1, 3, 32, 32], [3], [3], [1, 3, 32, 32], attrs={"epsilon": 1e-12} @@ -761,6 +915,11 @@ def test_instance_norm(): ) +def test_mean_variance_norm(): + verify_unary("MeanVarianceNormalization", [1, 3, 32, 32]) + verify_unary("MeanVarianceNormalization", [1, 3, 32, 32], attrs={"axes": (1, 2, 3)}) + + def test_layer_norm(): layer_norm_node = helper.make_node("LayerNormalization", ["a", "b", "c"], ["d"], epsilon=1e-12) @@ -1075,9 +1234,36 @@ def verify_arg_min_max(input_dim, in_dtype, op_name="ArgMax", axis=None, keepdim verify_arg_min_max([3, 4, 4], in_dtype, "ArgMin", axis, keepdims) +@pytest.mark.parametrize("axis", [-1, 0, 1]) +@pytest.mark.parametrize("largest", [True, False]) +def test_topk(axis: int, largest: int): + in_shape = [32, 32, 32] + k_value = 4 + out_shape = in_shape + out_shape[axis] = k_value + k = make_constant_node("k", TensorProto.INT64, [1], [k_value]) + node = onnx.helper.make_node( + "TopK", + inputs=["data", "k"], + outputs=["values", "indices"], + axis=axis, + largest=largest, + ) + graph = helper.make_graph( + [k, node], + "topk_test", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, in_shape)], + outputs=[ + helper.make_tensor_value_info("values", TensorProto.FLOAT, out_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, out_shape), + ], + ) + model = helper.make_model(graph, producer_name="topk_test") + + check_correctness(model) + + @pytest.mark.parametrize("dynamic", [False, True]) -# TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed. -@pytest.mark.skip("Produces ill-formed IR") def test_expand(dynamic): if dynamic: # TODO: Support dynamic shape for Expand @@ -1586,14 +1772,6 @@ def test_range(): check_correctness(model) -def test_less(): - verify_compare("Less", [32, 32]) - - -def test_less_equal(): - verify_compare("LessOrEqual", [32, 32]) - - def test_batch_norm(): batch_norm_node = helper.make_node( "BatchNormalization", ["x", "s", "bias", "mean", "var"], ["y"], epsilon=1e-2 @@ -1811,17 +1989,58 @@ def test_global_average_pool(): verify_unary("GlobalAveragePool", [1, 3, 32, 32, 32]) +def test_global_max_pool(): + verify_unary("GlobalMaxPool", [1, 3, 32]) + verify_unary("GlobalMaxPool", [1, 3, 32, 32]) + verify_unary("GlobalMaxPool", [1, 3, 32, 32, 32]) + + +@pytest.mark.parametrize("p", [1, 2, 3]) +def test_global_lp_pool(p: int): + verify_unary("GlobalLpPool", [1, 3, 32], attrs={"p": p}) + verify_unary("GlobalLpPool", [1, 3, 32, 32], attrs={"p": p}) + verify_unary("GlobalLpPool", [1, 3, 32, 32, 32], attrs={"p": p}) + + +@pytest.mark.parametrize("kernel_shape", [[2, 2], [3, 3]]) +@pytest.mark.parametrize("pads", [None, [1, 1, 1, 1]]) +@pytest.mark.parametrize("strides", [None, [2, 2]]) +def test_maxunpool(kernel_shape, pads, strides): + input_shape = [16, 3, 16, 16] + input_names = ["X", "I"] + input_info = [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("I", TensorProto.INT64, input_shape), + ] + + attrs = {"kernel_shape": kernel_shape} + if pads is not None: + attrs["pads"] = pads + if strides is not None: + attrs["strides"] = strides + + node = helper.make_node("MaxUnpool", inputs=input_names, outputs=["y"], **attrs) + + graph = helper.make_graph( + [node], + "maxunpool_test", + inputs=input_info, + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, None)], + ) + + max_random = int(np.prod(np.array(kernel_shape))) + indices = np.random.randint(0, max_random, size=input_shape) + + model = helper.make_model(graph, producer_name="maxunpool_test") + check_correctness(model, inputs={"I": indices}) + + def test_flatten(): verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 0}) verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": -1}) verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 2}) -def test_greater(): - verify_compare("Greater", [32, 32]) - verify_compare("Greater", [64, 16]) - - def test_onehot(): one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["y"], axis=1) graph = helper.make_graph( @@ -1844,8 +2063,189 @@ def test_onehot(): check_correctness(model, inputs=values) -def test_reciprocal(): - verify_unary("Reciprocal", [3, 32, 32]) +@pytest.mark.parametrize("axis", [None, 0, 1, -1]) +@pytest.mark.parametrize("sorted", [0, 1]) +def test_unique(axis: Optional[int], sorted: int): + input_shape = [32, 32] + if axis is None: + output_shape = [-1] + else: + output_shape = [32, 32] + output_shape[axis] = -1 + unique_node = helper.make_node("Unique", ["x"], ["y"], axis=axis, sorted=sorted) + graph = helper.make_graph( + [unique_node], + "unique_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape)], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="unique_test") + check_correctness(model) + + +@pytest.mark.parametrize("mode", ["DCR", "CRD"]) +def test_depth_to_space(mode: Literal["DCR", "CRD"]): + in_shape = [1, 8, 2, 3] + out_shape = [1, 2, 4, 6] + blocksize = 2 + node = onnx.helper.make_node( + "DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blocksize, mode=mode + ) + graph = helper.make_graph( + [node], + "depth_to_space_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, in_shape)], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)], + ) + model = helper.make_model(graph, producer_name="depth_to_space_test") + + check_correctness(model) + + +def test_space_to_depth(): + in_shape = [1, 2, 4, 6] + out_shape = [1, 8, 2, 3] + blocksize = 2 + node = onnx.helper.make_node("SpaceToDepth", inputs=["x"], outputs=["y"], blocksize=blocksize) + graph = helper.make_graph( + [node], + "space_to_depth_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, in_shape)], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)], + ) + model = helper.make_model(graph, producer_name="space_to_depth_test") + + check_correctness(model) + + +def construct_sequence(input_shape: List[int], num_tensors: int, name: str = "sequence"): + inputs = [f"data{i}" for i in range(num_tensors)] + sequence_construct_node = helper.make_node("SequenceConstruct", inputs, [name]) + graph_inputs = [ + helper.make_tensor_value_info(f"data{i}", TensorProto.FLOAT, input_shape) + for i in range(num_tensors) + ] + return sequence_construct_node, graph_inputs + + +def make_constant_node(name: str, data_type: int, dims: List[int], vals: List[int]): + return helper.make_node( + "Constant", + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, data_type=data_type, dims=dims, vals=vals), + ) + + +def test_sequence_construct(): + node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=2) + graph = helper.make_graph( + [node], + "test_sequence_construct", + inputs=graph_inputs, + outputs=[helper.make_tensor_sequence_value_info("sequence", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_sequence_construct") + check_correctness(model) + + +def test_sequence_empty(): + sequence_empty_node = helper.make_node("SequenceEmpty", [], ["sequence"]) + graph = helper.make_graph( + [sequence_empty_node], + "test_sequence_empty", + inputs=[], + outputs=[helper.make_tensor_sequence_value_info("sequence", TensorProto.FLOAT, [])], + ) + model = helper.make_model(graph, producer_name="test_sequence_empty") + check_correctness(model) + + +@pytest.mark.parametrize("explicit_position", [True, False]) +def test_sequence_erase(explicit_position: bool): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [1]) + node_input = ["sequence", "index"] if explicit_position else ["sequence"] + sequence_erase_node = helper.make_node("SequenceErase", node_input, ["output"]) + graph = helper.make_graph( + [index, seq_node, sequence_erase_node], + "test_sequence_erase", + inputs=graph_inputs, + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_sequence_erase") + check_correctness(model) + + +@pytest.mark.parametrize("explicit_position", [True, False]) +def test_sequence_insert(explicit_position: bool): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [0]) + node_input = ["sequence", "value", "index"] if explicit_position else ["sequence", "value"] + sequence_insert_node = helper.make_node("SequenceInsert", node_input, ["output"]) + graph = helper.make_graph( + [index, seq_node, sequence_insert_node], + "test_sequence_insert", + inputs=[*graph_inputs, helper.make_tensor_value_info("value", TensorProto.FLOAT, [32, 32])], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_sequence_insert") + check_correctness(model) + + +@pytest.mark.parametrize("new_axis", [0, 1]) +def test_concat_from_sequence(new_axis: Literal[0, 1]): + if new_axis == 1: + pytest.skip("ConcatFromSequence with new_axis=1 is not supported yet") + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=2) + concat_from_sequence_node = helper.make_node( + "ConcatFromSequence", ["sequence"], ["output"], axis=1 + ) + graph = helper.make_graph( + [seq_node, concat_from_sequence_node], + "test_concat_from_sequence", + inputs=graph_inputs, + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [64, 32])], + ) + model = helper.make_model(graph, producer_name="test_concat_from_sequence") + check_correctness(model) + + +@pytest.mark.parametrize("split", [2, [16, 48]]) +def test_split_to_sequence(split): + split_to_sequence_node = helper.make_node( + "SplitToSequence", + ["data", "split"], + ["output"], + axis=0, + ) + split_shape = [len(split)] if isinstance(split, list) else () + split_node = make_constant_node( + "split", TensorProto.INT64, split_shape, [split] if isinstance(split, int) else split + ) + graph = helper.make_graph( + [split_node, split_to_sequence_node], + "test_split_to_sequence", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [64, 32])], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_split_to_sequence") + check_correctness(model) + + +def test_sequence_at(): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [1]) + node_input = ["sequence", "index"] + sequence_at_node = helper.make_node("SequenceAt", node_input, ["output"]) + graph = helper.make_graph( + [index, seq_node, sequence_at_node], + "test_sequence_at", + inputs=graph_inputs, + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_sequence_at") + check_correctness(model) def test_symbolic_shape_deduction(): diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index fcb8727d8508..a80b988d06c4 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -60,7 +60,7 @@ def test_unique(exec_mode): result, result_sorted = run_cpu(InputModule, "foo", data, exec_mode=exec_mode) expected_output_sorted, indices = np.unique(data_numpy, return_index=True) - expected_output = [data_numpy.flatten()[index] for index in sorted(indices, reverse=True)] + expected_output = [data_numpy.flatten()[index] for index in sorted(indices)] np.testing.assert_array_equal(expected_output_sorted, result_sorted.numpy()) np.testing.assert_array_equal(expected_output, result.numpy()) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index d03d48968d90..12436cf8023f 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -204,6 +204,53 @@ def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1 tvm.ir.assert_structural_equal(mod, Expected) +def test_conv1d_transpose(): + # fmt: off + @I.ir_module + class Conv1dTranspose: + @R.function + def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((128, 16, 3), "float32")): + gv = R.nn.conv1d_transpose(x, w, strides=2, padding=1, dilation=1, output_padding=1, groups=8) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def conv1d_transpose(x: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), w: T.Buffer((T.int64(128), T.int64(16), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + data_dilate = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(55))) + data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58))) + kernel = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3))) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(55)): + with T.block("data_dilate"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + data_dilate[v_i0, v_i1, v_i2] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0), x[v_i0, v_i1, v_i2 // T.int64(2)], T.float32(0.0)) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(58)): + with T.block("data_pad"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + data_pad[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56), data_dilate[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0.0)) + for o, i, w_1 in T.grid(T.int64(16), T.int64(128), T.int64(3)): + with T.block("kernel"): + v_o, v_i, v_w = T.axis.remap("SSS", [o, i, w_1]) + kernel[v_o, v_i, v_w] = w[v_i, v_o, T.int64(2) - v_w] + for b, c, w_1, dc, dw in T.grid(T.int64(2), T.int64(128), T.int64(56), T.int64(16), T.int64(3)): + with T.block("compute"): + v_b, v_c, v_w, v_dc, v_dw = T.axis.remap("SSSRR", [b, c, w_1, dc, dw]) + with T.init(): + compute[v_b, v_c, v_w] = T.float32(0.0) + compute[v_b, v_c, v_w] = compute[v_b, v_c, v_w] + data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_w + v_dw] * kernel[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dw] + + @R.function + def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((128, 16, 3), dtype="float32")) -> R.Tensor((2, 128, 56), dtype="float32"): + cls = Expected + gv = R.call_tir(cls.conv1d_transpose, (x, w), out_sinfo=R.Tensor((2, 128, 56), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Conv1dTranspose) + tvm.ir.assert_structural_equal(mod, Expected) + + def test_conv2d(): # fmt: off @tvm.script.ir_module From 24fd0379270ec3e4ed67e7d0fadd211dc653d639 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 3 Oct 2024 12:29:58 -0700 Subject: [PATCH 189/202] [TVMScript] Enable T.macro decorateing class method (#17435) * [TVMScript] Enable T.macro decorateing class method This PR refactors the implementation of `T.macro`, so that the `self` argument can be passed through the TVMScript parser. Then we can decroate the class methods with `T.macro`. * update test --- python/tvm/script/parser/core/parser.py | 4 +- python/tvm/script/parser/relax/entry.py | 7 +++- python/tvm/script/parser/tir/entry.py | 7 +++- .../tvmscript/test_tvmscript_parser_tir.py | 42 +++++++++++++++++-- 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 372a3c54e4c5..f40b9a7cf6d3 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -135,9 +135,9 @@ def _find_parser_def(self): def get_macro_def(self): ast_module = self.source.as_ast() for decl in ast_module.body: - if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__: + if isinstance(decl, doc.FunctionDef) and decl.name == self.func.__name__: return decl - raise RuntimeError(f"cannot find macro definition for {self.__name__}") + raise RuntimeError(f"cannot find macro definition for {self.func.__name__}") def __call__(self, *args, **kwargs): param_binding = inspect.signature(self.func).bind(*args, **kwargs) diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 73a5d7149a81..04a5f985643e 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -128,8 +128,11 @@ def macro(*args, hygienic: bool = True) -> _Callable: def _decorator(func: _Callable) -> ScriptMacro: source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) obj = RelaxMacro(source, closure_vars, func, hygienic) - obj.__name__ = func.__name__ - return obj + + def wrapper(*args, **kwargs): + return obj(*args, **kwargs) + + return wrapper if len(args) == 0: return _decorator diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 79eb88dfc102..c7d5dc756b32 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -139,8 +139,11 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: def _decorator(func: Callable) -> TIRMacro: source, closure_vars = scan_macro(func, utils.inspect_function_capture(func)) obj = TIRMacro(source, closure_vars, func, hygienic) - obj.__name__ = func.__name__ - return obj + + def wrapper(*args, **kwargs): + return obj(*args, **kwargs) + + return wrapper if len(args) == 0: return _decorator diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 2dcbc89d47a6..16b206751402 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -116,8 +116,6 @@ def evaluate0(): def func1(): T.evaluate(0) - assert func1.hygienic - @T.prim_func(private=True) def use1(): func1() @@ -129,8 +127,6 @@ def use1(): def func2(): T.evaluate(0) - assert func2.hygienic - @T.prim_func(private=True) def use2(): func2() @@ -212,6 +208,44 @@ def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32" tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic) +def test_tir_macro_in_class(): + class Object: + def __init__(self, x: T.Buffer): + self.local_x = T.alloc_buffer(x.shape, x.dtype) + + @T.macro + def load(self, x: T.Buffer): + N, M = T.meta_var(self.local_x.shape) + for i, j in T.grid(N, M): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + self.local_x[vi, vj] = x[vi, vj] + + @T.prim_func(private=True) + def func_w_macro(a: T.handle): + A = T.match_buffer(a, [128, 128]) + o1 = T.meta_var(Object(A)) + o1.load(A) + o2 = T.meta_var(Object(A)) + o2.load(o1.local_x) + + @T.prim_func(private=True) + def func_no_macro(a: T.handle): + A = T.match_buffer(a, [128, 128]) + local_a = T.alloc_buffer([128, 128]) + for i, j in T.grid(128, 128): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + local_a[vi, vj] = A[vi, vj] + local_b = T.alloc_buffer([128, 128]) + for i, j in T.grid(128, 128): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i, j]) + local_b[vi, vj] = local_a[vi, vj] + + tvm.ir.assert_structural_equal(func_no_macro, func_w_macro) + + def test_tir_starred_expression(): dims = (128, 128) From ba0881ef24d17a11d7a46e4d662cb4b1632a652c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81goston=20Czobor?= <73029973+agoston-mc@users.noreply.github.com> Date: Fri, 4 Oct 2024 07:55:26 +0200 Subject: [PATCH 190/202] [Docker][CI] Add NNEF dependency to CI images (#17433) [Docker][CI] Add NNEF dependency --- docker/Dockerfile.ci_arm | 4 ++++ docker/Dockerfile.ci_cortexm | 4 ++++ docker/Dockerfile.ci_cpu | 4 ++++ docker/Dockerfile.ci_gpu | 3 +++ docker/Dockerfile.ci_hexagon | 4 ++++ docker/Dockerfile.ci_riscv | 4 ++++ docker/install/ubuntu_install_nnef.sh | 25 +++++++++++++++++++++++++ docker/python/ci-constraints.txt | 2 ++ 8 files changed, 50 insertions(+) create mode 100644 docker/install/ubuntu_install_nnef.sh diff --git a/docker/Dockerfile.ci_arm b/docker/Dockerfile.ci_arm index 2be887079e34..16ffecb315e9 100644 --- a/docker/Dockerfile.ci_arm +++ b/docker/Dockerfile.ci_arm @@ -75,6 +75,10 @@ RUN bash /install/ubuntu_install_tflite.sh COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh +# NNEF +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + # AutoTVM deps COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh diff --git a/docker/Dockerfile.ci_cortexm b/docker/Dockerfile.ci_cortexm index 8006b27e84c2..5535d29ed104 100644 --- a/docker/Dockerfile.ci_cortexm +++ b/docker/Dockerfile.ci_cortexm @@ -108,6 +108,10 @@ RUN bash /install/ubuntu_install_arduino.sh COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh +# NNEF +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + # Install CMSIS_NN COPY install/ubuntu_install_cmsis.sh /install/ubuntu_install_cmsis.sh RUN bash /install/ubuntu_install_cmsis.sh /opt/arm/ethosu/cmsis diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 37c7c9085714..9e53882e1638 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -134,6 +134,10 @@ RUN bash /install/ubuntu_install_libxsmm.sh COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh +# NNEF +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + # AArch64 Architecture Envelope Model (AEM) COPY install/ubuntu_install_aprofile_aem.sh /install RUN bash /install/ubuntu_install_aprofile_aem.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 1a5721c549ab..7f5a68911c6a 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -104,6 +104,9 @@ RUN bash /install/ubuntu_install_libtorch.sh COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh RUN bash /install/ubuntu_install_tflite.sh +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + COPY install/ubuntu_install_dgl.sh /install/ubuntu_install_dgl.sh RUN bash /install/ubuntu_install_dgl.sh diff --git a/docker/Dockerfile.ci_hexagon b/docker/Dockerfile.ci_hexagon index 11b3041f3c56..489894d252ae 100644 --- a/docker/Dockerfile.ci_hexagon +++ b/docker/Dockerfile.ci_hexagon @@ -84,6 +84,10 @@ RUN bash /install/ubuntu_install_tflite.sh COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh +# NNEF +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + # xgboost (for tuning) COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh diff --git a/docker/Dockerfile.ci_riscv b/docker/Dockerfile.ci_riscv index d1b5a033b6e7..c26470985a92 100644 --- a/docker/Dockerfile.ci_riscv +++ b/docker/Dockerfile.ci_riscv @@ -75,6 +75,10 @@ RUN bash /install/ubuntu_install_tflite.sh COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh +# NNEF +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + # sccache COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh RUN bash /install/ubuntu_install_sccache.sh diff --git a/docker/install/ubuntu_install_nnef.sh b/docker/install/ubuntu_install_nnef.sh new file mode 100644 index 000000000000..6cd4761787c5 --- /dev/null +++ b/docker/install/ubuntu_install_nnef.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +pip3 install \ + nnef_tools==1.0.6 \ + nnef==1.0.7 diff --git a/docker/python/ci-constraints.txt b/docker/python/ci-constraints.txt index 003c13170411..feba27cd03d0 100644 --- a/docker/python/ci-constraints.txt +++ b/docker/python/ci-constraints.txt @@ -37,3 +37,5 @@ tflite = "==2.4.0" torch = "==1.11.0" torchvision = "==0.12.0+cpu" #xgboost = "==1.4.2" +nnef = "==1.0.7" +nnef_tools = "==1.0.6" From accd582d3a006b6c3473187e1c155fa535343d8a Mon Sep 17 00:00:00 2001 From: Yongqi Date: Sat, 5 Oct 2024 15:32:31 +0800 Subject: [PATCH 191/202] =?UTF-8?q?[BugFix][TIR][Schedule]=20TileWithTenso?= =?UTF-8?q?rIntrin=20skip=20ComputeInline=20if=20bu=E2=80=A6=20(#17440)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [BugFix][TIR][Schedule] TileWithTensorIntrin skip ComputeInline if buffer not padded by PadEinsum --- src/tir/schedule/transform.cc | 63 +++- ...test_meta_schedule_schedule_rule_mlt_tc.py | 295 ++++++++++++++++++ 2 files changed, 346 insertions(+), 12 deletions(-) diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index fec214fa1fc7..c644fbecdf5c 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -326,23 +326,62 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block if (!opt_tensorize_info) return NullOpt; const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); if (info->block_iter_paddings.defined()) { + // We have to track whether each producer or consumer is padded. + // To do so, we first record all the Block's. + std::unordered_set original_producers, original_consumers; + { + for (const auto& p : GetProducers(sch->state(), sch->GetSRef(block_rv))) + original_producers.insert(p.get()); + for (const auto& c : GetConsumers(sch->state(), sch->GetSRef(block_rv))) + original_consumers.insert(c.get()); + } + + // Pad. Maybe we can make PadEinsum return the changes it made, to avoid bookkeeping? sch->PadEinsum(block_rv, info->block_iter_paddings.value()); + + // Now we need to find out all the padded Block's. + Array inlined_producers, inlined_consumers; + for (const auto& producer : sch->GetProducers(block_rv)) { + // PadEinsum will not modify the producer if it does not need padding. + if (original_producers.count(sch->GetSRef(producer).get())) { + // Producer not padded. No inlining. + continue; + } + auto the_original_producers = sch->GetProducers(producer); + if (the_original_producers.empty()) { + // The original producer is input. + continue; + } + ICHECK_EQ(the_original_producers.size(), 1u); + auto the_original_producer = the_original_producers[0]; + ICHECK(original_producers.count(sch->GetSRef(the_original_producer).get())); + inlined_producers.push_back(the_original_producer); + } + for (const auto& consumer : sch->GetConsumers(block_rv)) { + // PadEinsum will not modify the consumer if it does not need padding. + if (original_consumers.count(sch->GetSRef(consumer).get())) { + // Consumer not padded. No inlining. + continue; + } + auto the_original_consumers = sch->GetConsumers(consumer); + if (the_original_consumers.empty()) { + // The original consumer is output. + continue; + } + ICHECK_EQ(the_original_consumers.size(), 1u); + auto the_original_consumer = the_original_consumers[0]; + ICHECK(original_consumers.count(sch->GetSRef(the_original_consumer).get())); + inlined_consumers.push_back(consumer); + } + // Inline the producer and consumer padding blocks - auto producers = sch->GetProducers(block_rv); - for (const auto& producer : producers) { - auto original_producers = sch->GetProducers(producer); - // NOTICE: there may not all producers padded. + for (const auto& the_original_producer : inlined_producers) { // Inline the original producer into the padding block. This ensures that the new producer // has the padded shape. - if (original_producers.size() == 1u) { - sch->ComputeInline(original_producers[0]); - } + sch->ComputeInline(the_original_producer); } - auto consumers = sch->GetConsumers(block_rv); - for (const auto& consumer : consumers) { - auto sref = sch->GetSRef(consumer); - if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true))) - sch->ComputeInline(consumer); + for (const auto& consumer : inlined_consumers) { + sch->ComputeInline(consumer); } } // Construct a mapping from tir loops back to LoopRVs diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py index 1fd2ab84749e..be936e6e84fb 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py @@ -1207,5 +1207,300 @@ def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buf ) +def test_padded_matmul_single_padded_input(): + # fmt: off + @T.prim_func + def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), "float16"), B: T.Buffer((4096, 1024), "float16"), C: T.Buffer((1023, 1024), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + C_reindex_pad_shared = T.alloc_buffer((8, 32, 8, 2, 16, 16), scope="shared") + C_reindex_pad_shared_wmma_accumulator = T.alloc_buffer((8, 32, 8, 2, 16, 16), scope="wmma.accumulator") + A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((4096, 1024), "float16", scope="shared") + A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(32, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_0_0 in range(32): + for ax0_ax1_fused in range(65536): + with T.block("A_reindex_pad_shared"): + v0 = T.axis.spatial(1024, ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_fused // 128) + v1 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused % 128) + T.reads(A[v0, v1]) + T.writes(A_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + A_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 1023, A[v0, v1], T.float16(0.0)) + for ax0_ax1_fused in range(8192): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused // 64) + v1 = T.axis.spatial(1024, ax0_0_1_ax1_0_1_fused % 16 * 64 + ax0_ax1_fused % 64) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in range(8): + for ax0_0, ax1_0 in T.grid(8, 1): + with T.block("A_reindex_pad_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0) + v1_o = T.axis.spatial(256, ax2_0_0 * 8 + ax2_0_1 + ax1_0) + T.reads(A_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_pad_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(256, ax2_0_0 * 8 + ax2_0_1 + ax0_0) + v1_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 1, 1, 4, 2): + with T.block("C_o"): + v0_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0_3 * 4 + ax0_0_4) + v1_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3 * 2 + ax1_0_4) + v2_o = T.axis.reduce(256, ax2_0_0 * 8 + ax2_0_1 + ax2_0_2) + T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i_init, v1_i_init]) + C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0.0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i] = C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(8): + for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("C_reindex_pad_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 * 4 + ax0_ax1_fused // 2) + v1_o = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_fused % 2) + v2_o = T.axis.spatial(8, ax2 + ax2_1) + v3_o = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_pad_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + T.writes(C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(4096): + with T.block("C_reindex_pad_shared"): + v0 = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024) + v1 = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_ax3_ax4_ax5_fused % 1024 // 512) + v2 = T.axis.spatial(8, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.where(ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_ax3_ax4_ax5_fused // 1024 * 128 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 1023) + T.reads(C_reindex_pad_shared[v0, v1, v2, v3, v4, v5]) + T.writes(C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32] = C_reindex_pad_shared[v0, v1, v2, v3, v4, v5] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [1, 2, 4, 2, 4]), + ("SamplePerfectTile", [1, 16, 2, 1, 2]), + ("SamplePerfectTile", [32, 8, 1]), + ("SampleCategorical", 3), + ("SampleCategorical", 1), + ("SampleCategorical", 0), + ] + mod = te.create_prim_func( + te_workload.matmul( + n=1023, + m=1024, + k=4096, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_70"), + types=None, + sch_rules=[multi_level_tiling_tensor_core()] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[padded_matmul_single_padded_input_0], + expected_decisions=[decision_0], + ) + + +def test_padded_matmul_no_padded_output(): + # fmt: off + @T.prim_func + def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), B: T.Buffer((4095, 1024), "float16"), C: T.Buffer((1024, 1024), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + C_reindex_shared = T.alloc_buffer((32, 16, 2, 4, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((32, 16, 2, 4, 16, 16), scope="wmma.accumulator") + A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", scope="shared") + B_reindex_pad_shared = T.alloc_buffer((4096, 1024), "float16", scope="shared") + A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096), "float16", scope="wmma.matrix_a") + B_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(64, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): + for ax2_0_0 in range(128): + for ax0_ax1_fused in range(4096): + with T.block("A_reindex_pad_shared"): + v0 = T.axis.spatial(1024, ax0_0_0_ax1_0_0_fused // 16 * 256 + ax0_0_1_ax1_0_1_fused * 128 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(4096, ax2_0_0 * 32 + ax0_ax1_fused % 32) + T.reads(A[v0, v1]) + T.writes(A_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + A_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 4095, A[v0, v1], T.float16(0.0)) + for ax0_ax1_fused in range(2048): + with T.block("B_reindex_pad_shared"): + v0 = T.axis.spatial(4096, ax2_0_0 * 32 + ax0_ax1_fused // 64) + v1 = T.axis.spatial(1024, ax0_0_0_ax1_0_0_fused % 16 * 64 + ax0_ax1_fused % 64) + T.reads(B[v0, v1]) + T.writes(B_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 4095, B[v0, v1], T.float16(0.0)) + for ax2_0_1 in range(2): + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("A_reindex_pad_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + ax0_0_2_ax1_0_2_fused * 2 + ax0_0) + v1_o = T.axis.spatial(256, ax2_0_0 * 2 + ax2_0_1 + ax1_0) + T.reads(A_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_pad_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 4): + with T.block("B_reindex_pad_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(256, ax2_0_0 * 2 + ax2_0_1 + ax0_0) + v1_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0) + T.reads(B_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_pad_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 1, 1, 1, 4): + with T.block("C_o"): + v0_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + ax0_0_2_ax1_0_2_fused * 2 + ax0_0_3 + ax0_0_4) + v1_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0_3 * 4 + ax1_0_4) + v2_o = T.axis.reduce(256, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) + T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0.0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(4, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 4): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_fused) + v1_o = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16) + v2_o = T.axis.spatial(2, ax2 + ax2_1) + v3_o = T.axis.spatial(4, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(C_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + T.writes(C_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + C_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(4096): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024) + v1 = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = C_reindex_shared[v0, v1, v2, v3, v4, v5] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [4, 2, 4, 2, 1]), + ("SamplePerfectTile", [16, 1, 1, 1, 4]), + ("SamplePerfectTile", [128, 2, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 3), + ("SampleCategorical", 0), + ] + mod = te.create_prim_func( + te_workload.matmul( + n=1024, + m=1024, + k=4095, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_70"), + types=None, + sch_rules=[multi_level_tiling_tensor_core()] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[padded_matmul_no_padded_output_0], + expected_decisions=[decision_0], + ) + + if __name__ == "__main__": tvm.testing.main() From ff0b07ba6f225128fb030ebb0f45704d44812f00 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sun, 6 Oct 2024 21:54:13 +0800 Subject: [PATCH 192/202] [TIR] Add `is_vector` Method to DataType class and update usages across Codebase (#17443) * Refactor data_type.h and c_runtime_api.h This commit refactors the `data_type.h` and `c_runtime_api.h` files. It introduces a new function `is_vector()` in the `DataType` class to check if a type is a vector type. Additionally, it adds a new constant `kTVMGridConstant` in the `TVMTypeCode` enum in `c_runtime_api.h`. These changes improve the code organization and provide better support for vector types. * revert kTVMGridConstant * lint fix --- include/tvm/runtime/data_type.h | 2 ++ include/tvm/topi/elemwise.h | 2 +- src/target/llvm/codegen_llvm.cc | 2 +- src/target/llvm/intrin_rule_hexagon.cc | 8 ++++---- src/tir/analysis/verify_gpu_code.cc | 8 ++++---- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index a330ccbbdf65..c49fde1746bc 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -148,6 +148,8 @@ class DataType { bool is_fixed_length_vector() const { return static_cast(data_.lanes) > 1; } /*! \return Whether the type is a scalable vector. */ bool is_scalable_vector() const { return static_cast(data_.lanes) < -1; } + /*! \return whether type is a vector type. */ + bool is_vector() const { return lanes() > 1; } /*! \return whether type is a bool vector type. */ bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; } /*! \return whether type is a Void type. */ diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 132992c57dc7..806ddcb662f9 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { if (expr.dtype().lanes() == type.lanes()) { return expr; - } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { + } else if (expr.dtype().lanes() == 1 && type.is_vector()) { return tvm::tir::Broadcast(expr, type.lanes()); } } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index e21436e556ee..3d6d3a9461d3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper( if (const RampNode* ramp = last_index.as()) { PrimExpr offset = ramp->base + (ramp->stride * i); last_index_value = MakeValue(offset); - } else if (last_index.dtype().lanes() > 1) { + } else if (last_index.dtype().is_vector()) { if (i == 0) { cached_vector_index = MakeValue(last_index); } diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index 7c4b38c1d702..2661f2fa6591 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { // Enable QHL library for FP16 data type const PrimExpr& x = call->args[0]; - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { return TVMExternCall(call, tvm_wrapper); } #endif @@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh") } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr( } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid") const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf"); return TVMExternCall(new_call.get(), tvm_wrapper); } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index f012f8a1b35e..8eda537579e7 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t size = static_cast(op->ConstantAllocationSize()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } - if (op->dtype.lanes() > 1) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const CastNode* op) { - if (op->dtype.lanes() > 1) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) { - if (op->dtype.lanes() > 1) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) { - if (op->value->dtype.lanes() > 1) { + if (op->value->dtype.is_vector()) { if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; From ba80646639d863a07e360dc377d592d1469efb73 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 7 Oct 2024 21:38:44 +0800 Subject: [PATCH 193/202] [ONNX] Move relax related tests to the correct file (#17447) There are a few relax tests in `tests/python/frontend/onnx/test_forward.py`, which is used for relay frontend. This commit moves them to the correct file. --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 10 +-- tests/python/frontend/onnx/test_forward.py | 62 ------------------- tests/python/relax/test_frontend_onnx.py | 43 +++++++++++++ 3 files changed, 49 insertions(+), 66 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 5777f51fe296..36a7823f8655 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -740,10 +740,12 @@ def _impl_v14(cls, bb, inputs, attr, params): x = inputs[0] k = inputs[1] if len(inputs) > 1 else 0 - if isinstance(k, relax.Var) and k.name_hint in params: - k = get_constant(k, params) - elif isinstance(k, relax.Constant): - k = int(k.data.numpy()[0]) + if len(inputs) > 1: + k = get_constant(inputs[1], params) + if isinstance(k, relax.Constant): + k = int(k.data.numpy()[0]) + else: + raise ValueError("Currently only support constant k for Trilu op.") else: k = 0 diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a5811d0dbd46..a81352bb679f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -37,7 +37,6 @@ from tvm.contrib import graph_executor, utils from tvm.relay.frontend.common import infer_type from tvm.relay.build_module import bind_params_by_name -from tvm.relax.frontend.onnx import from_onnx from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span import onnx @@ -5441,67 +5440,6 @@ def verify_softplus(indata): verify_softplus(input_data) -def test_load_cumsum(): - """test_load_cumsum""" - - def create_cumsum_model(): - input_shape = [2, 3] - - graph = helper.make_graph( - [ - helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]), - ], - "cumsum_graph", - inputs=[ - helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape), - helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"), - ], - outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)], - ) - return helper.make_model(graph) - - from_onnx(create_cumsum_model()) - - -def test_load_trilu(): - """test_load_trilu""" - - def create_trilu_model(): - input_shape = [2, 3, 3] - - graph = helper.make_graph( - [ - helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]), - ], - "trilu_graph", - inputs=[ - helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape), - helper.make_tensor_value_info("k", onnx.TensorProto.INT32, [1], "k"), - ], - outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)], - ) - return helper.make_model(graph) - - def create_trilu_model_const_k(): - input_shape = [2, 3, 3] - - graph = helper.make_graph( - [ - make_constant_node("k", onnx.TensorProto.INT32, [1], [1]), - helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]), - ], - "trilu_graph", - inputs=[ - helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape), - ], - outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)], - ) - return helper.make_model(graph) - - from_onnx(create_trilu_model()) - from_onnx(create_trilu_model_const_k()) - - @tvm.testing.parametrize_targets def test_cumsum(target, dev): """test_cumsum""" diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 2837ad2185e9..f2bbd3f3f585 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -710,6 +710,28 @@ def test_trilu(upper: bool): verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper}) +@pytest.mark.parametrize("k_value", [-1, 0, 1]) +def test_trilu_with_const_k(k_value: int): + """test_trilu_with_const_k""" + + input_shape = [2, 3, 3] + + graph = helper.make_graph( + [ + make_constant_node("k", onnx.TensorProto.INT64, [1], [k_value]), + helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]), + ], + "trilu_graph", + inputs=[ + helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape), + ], + outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)], + ) + + model = helper.make_model(graph, producer_name="trilu_graph") + check_correctness(model) + + def test_selu(): verify_unary("Selu", [3, 32, 32]) verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3}) @@ -859,6 +881,27 @@ def test_cumsum(reverse, exclusive): check_correctness(model) +def test_cumsum1(): + """test_cumsum1""" + + input_shape = [2, 3] + + graph = helper.make_graph( + [ + helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]), + ], + "cumsum_graph", + inputs=[ + helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape), + helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"), + ], + outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)], + ) + + model = helper.make_model(graph, producer_name="cumsum_graph") + check_correctness(model) + + @pytest.mark.parametrize("axis", [[0, 2], None]) def test_squeeze(axis): if axis: From a5d04a5e89e55f5152e7716601c1f354d5d22b8f Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 7 Oct 2024 23:18:08 +0900 Subject: [PATCH 194/202] [CI][Docs] Upgrade Sphinx (#17444) * upgrade sphinx * try latest version of sphinx * install tlcpack-sphinx-addon --- docker/install/ubuntu_install_sphinx.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index 96023fa6e633..bbaf04976691 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -20,14 +20,14 @@ set -e set -u set -o pipefail -# NOTE: install docutils < 0.17 to work around https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 pip3 install \ autodocsumm \ - "commonmark>=0.7.3" \ - "docutils>=0.11,<0.17" \ + commonmark \ + docutils \ Image \ matplotlib \ - sphinx==4.2.0 \ + sphinx \ sphinx_autodoc_annotation \ - "git+https://github.com/sphinx-gallery/sphinx-gallery.git@6142f1791151849b5bec4bf3959f75697ba226cd" \ - sphinx_rtd_theme + sphinx-gallery \ + sphinx_rtd_theme \ + https://github.com/tlc-pack/tlcpack-sphinx-addon/archive/refs/tags/v0.2.3.zip From abb901f08cdc646d69758eb32503dcab59a904e0 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 7 Oct 2024 22:56:54 +0800 Subject: [PATCH 195/202] [Relax] Support left_shift and right_shift op (#17448) Introduced left_shift and right_shift op in Relax with ONNX frontend support. --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 104 ++++++++++++++++-- python/tvm/relax/op/__init__.py | 2 + python/tvm/relax/op/binary.py | 32 ++++++ .../relax/transform/legalize_ops/binary.py | 2 + python/tvm/script/ir_builder/relax/ir.py | 4 + src/relax/op/distributed/binary.cc | 2 + src/relax/op/tensor/binary.cc | 2 + src/relax/op/tensor/binary.h | 6 + tests/python/relax/test_frontend_onnx.py | 36 ++++++ tests/python/relax/test_op_binary.py | 2 + 10 files changed, 184 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 36a7823f8655..aa156a025fef 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -244,7 +244,8 @@ class BinaryBase(OnnxOpConverter): relax_op: Callable = None @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def base_impl(cls, bb, inputs, attr, params): + """Base implementation for binary operations.""" if cls.numpy_op is None or cls.relax_op is None: raise ValueError("Numpy and Relax operators must be defined for BinaryBase.") if all([isinstance(inp, relax.Constant) for inp in inputs]): @@ -274,6 +275,10 @@ class Add(BinaryBase): numpy_op = _np.add relax_op = relax.op.add + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Sub(BinaryBase): """Converts an onnx Sub node into an equivalent Relax expression.""" @@ -281,6 +286,10 @@ class Sub(BinaryBase): numpy_op = _np.subtract relax_op = relax.op.subtract + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Mul(BinaryBase): """Converts an onnx Mul node into an equivalent Relax expression.""" @@ -288,6 +297,10 @@ class Mul(BinaryBase): numpy_op = _np.multiply relax_op = relax.op.multiply + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Div(BinaryBase): """Converts an onnx Div node into an equivalent Relax expression.""" @@ -295,6 +308,10 @@ class Div(BinaryBase): numpy_op = _np.divide relax_op = relax.op.divide + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Pow(BinaryBase): """Converts an onnx Pow node into an equivalent Relax expression.""" @@ -302,6 +319,10 @@ class Pow(BinaryBase): numpy_op = _np.power relax_op = relax.op.power + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class And(BinaryBase): """Converts an onnx And node into an equivalent Relax expression.""" @@ -309,6 +330,10 @@ class And(BinaryBase): numpy_op = _np.logical_and relax_op = relax.op.logical_and + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Or(BinaryBase): """Converts an onnx Or node into an equivalent Relax expression.""" @@ -316,6 +341,10 @@ class Or(BinaryBase): numpy_op = _np.logical_or relax_op = relax.op.logical_or + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Xor(BinaryBase): """Converts an onnx Xor node into an equivalent Relax expression.""" @@ -323,6 +352,10 @@ class Xor(BinaryBase): numpy_op = _np.logical_xor relax_op = relax.op.logical_xor + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Less(BinaryBase): """Converts an onnx Less node into an equivalent Relax expression.""" @@ -330,6 +363,10 @@ class Less(BinaryBase): numpy_op = _np.less relax_op = relax.op.less + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class LessOrEqual(BinaryBase): """Converts an onnx LessEqual node into an equivalent Relax expression.""" @@ -337,6 +374,10 @@ class LessOrEqual(BinaryBase): numpy_op = _np.less_equal relax_op = relax.op.less_equal + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Greater(BinaryBase): """Converts an onnx Greater node into an equivalent Relax expression.""" @@ -344,6 +385,10 @@ class Greater(BinaryBase): numpy_op = _np.greater relax_op = relax.op.greater + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class GreaterOrEqual(BinaryBase): """Converts an onnx GreaterEqual node into an equivalent Relax expression.""" @@ -351,6 +396,10 @@ class GreaterOrEqual(BinaryBase): numpy_op = _np.greater_equal relax_op = relax.op.greater_equal + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Equal(OnnxOpConverter): """Converts an onnx Equal node into an equivalent Relax expression.""" @@ -374,7 +423,8 @@ class BitwiseBase(BinaryBase): """Converts an onnx BitwiseBase node into an equivalent Relax expression.""" @classmethod - def base_impl(cls, bb, inputs, attr, params, py_func, relax_op): + def base_impl(cls, bb, inputs, attr, params): + """Base implementation for bitwise operations.""" valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] for num, inp in enumerate(inputs): if inp.struct_info.dtype not in valid_types: @@ -382,31 +432,69 @@ def base_impl(cls, bb, inputs, attr, params, py_func, relax_op): f"Bitwise operations expect all inputs to have integer types, " f"got {inp.struct_info.dtype} for input {num}" ) - return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op) + return super().base_impl(bb, inputs, attr, params) class BitwiseAnd(BitwiseBase): """Converts an onnx BitwiseAnd node into an equivalent Relax expression.""" + numpy_op = _np.bitwise_and + relax_op = relax.op.bitwise_and + @classmethod def _impl_v18(cls, bb, inputs, attr, params): - return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and) + return cls.base_impl(bb, inputs, attr, params) class BitwiseOr(BitwiseBase): """Converts an onnx BitwiseOr node into an equivalent Relax expression.""" + numpy_op = _np.bitwise_or + relax_op = relax.op.bitwise_or + @classmethod def _impl_v18(cls, bb, inputs, attr, params): - return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or) + return cls.base_impl(bb, inputs, attr, params) class BitwiseXor(BitwiseBase): """Converts an onnx BitwiseXor node into an equivalent Relax expression.""" + numpy_op = _np.bitwise_xor + relax_op = relax.op.bitwise_xor + @classmethod def _impl_v18(cls, bb, inputs, attr, params): - return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor) + return cls.base_impl(bb, inputs, attr, params) + + +class BitwiseNot(BitwiseBase): + """Converts an onnx BitwiseNot node into an equivalent Relax expression.""" + + numpy_op = _np.bitwise_not + relax_op = relax.op.bitwise_not + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + + +class BitShift(BitwiseBase): + """Converts an onnx BitShift node into an equivalent Relax expression.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + direction = attr.get("direction", "LEFT").decode("ascii") + if direction == "LEFT": + cls.numpy_op = _np.left_shift + cls.relax_op = relax.op.left_shift + elif direction == "RIGHT": + cls.numpy_op = _np.right_shift + cls.relax_op = relax.op.right_shift + else: + raise ValueError("Unsupported Shift Direction: " + direction) + + return cls.base_impl(bb, inputs, attr, params) class Sigmoid(OnnxOpConverter): @@ -2654,8 +2742,8 @@ def _get_convert_map(): "BitwiseAnd": BitwiseAnd, "BitwiseOr": BitwiseOr, "BitwiseXor": BitwiseXor, - # "BitwiseNot": BitwiseNot, - # "BitwiseShift": BitwiseShift, + "BitwiseNot": BitwiseNot, + "BitShift": BitShift, "And": And, "Or": Or, "Xor": Xor, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 4581defa1a77..c99201e969b5 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -52,6 +52,7 @@ floor_divide, greater, greater_equal, + left_shift, less, less_equal, logical_and, @@ -62,6 +63,7 @@ multiply, not_equal, power, + right_shift, subtract, ) from .create import ( diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 982b3a24f26c..7632235cb32c 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -386,3 +386,35 @@ def bitwise_xor(x1: Expr, x2: Expr) -> Expr: The computed result. """ return _ffi_api.bitwise_xor(x1, x2) + + +def left_shift(x1: Expr, x2: Expr) -> Expr: + """Bitwise Shift Left + Parameters + ---------- + x1 : relax.Expr + The input tensor to be shifted. + x2 : relax.Expr + The number of positions to shift. + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.left_shift(x1, x2) + + +def right_shift(x1: Expr, x2: Expr) -> Expr: + """Bitwise Shift Right + Parameters + ---------- + x1 : relax.Expr + The input tensor to be shifted. + x2 : relax.Expr + The number of positions to shift. + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.right_shift(x1, x2) diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py index 16d6c0269616..d28e100edb9f 100644 --- a/python/tvm/relax/transform/legalize_ops/binary.py +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -62,6 +62,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.bitwise_and", _binary(topi.bitwise_and)) register_legalize("relax.bitwise_or", _binary(topi.bitwise_or)) register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor)) +register_legalize("relax.left_shift", _binary(topi.left_shift)) +register_legalize("relax.right_shift", _binary(topi.right_shift)) # logical register_legalize("relax.logical_and", _binary(topi.logical_and)) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index c4be8afac4d2..e6ff35ebe56b 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -102,6 +102,7 @@ isinf, isnan, layout_transform, + left_shift, less, less_equal, linear, @@ -133,6 +134,7 @@ quantize, repeat, reshape, + right_shift, round, rsqrt, scatter_elements, @@ -773,6 +775,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "isinf", "isnan", "layout_transform", + "left_shift", "less", "less_equal", "linear", @@ -809,6 +812,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "repeat", "reshape", "rewriter", + "right_shift", "tensor_to_shape", "shape_to_tensor", "rocm", diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc index 63f4f356c03d..6ad71e0f85bf 100644 --- a/src/relax/op/distributed/binary.cc +++ b/src/relax/op/distributed/binary.cc @@ -68,6 +68,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(logical_xor); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_and); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_or); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_xor); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(left_shift); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(right_shift); } // namespace distributed } // namespace relax diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index afc0fb73031b..f1dc3d4904c8 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -207,6 +207,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(left_shift); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(right_shift); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index b28a6c33690b..003bcb7e27cf 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -129,6 +129,12 @@ Expr bitwise_or(Expr x1, Expr x2); /*! \brief Broadcasted element-wise bitwise xor */ Expr bitwise_xor(Expr x1, Expr x2); +/*! \brief Broadcasted element-wise bitwise shift left */ +Expr left_shift(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise bitwise shift right */ +Expr right_shift(Expr x1, Expr x2); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index f2bbd3f3f585..e3ed3a3a9d4d 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -358,6 +358,42 @@ def test_binary_bool(op_name: str): verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.BOOL) +@pytest.mark.skip(reason="opset 18 is not supported in CI") +@pytest.mark.parametrize("op_name", ["BitwiseAnd", "BitwiseOr", "BitwiseXor"]) +def test_bitwise(op_name: str): + verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.UINT64, opset=18) + + +@pytest.mark.skip(reason="opset 18 is not supported in CI") +def test_bitwise_not(): + verify_unary( + "BitwiseNot", + [32, 32], + input_dtype=TensorProto.UINT64, + output_dtype=TensorProto.UINT64, + opset=18, + ) + + +@pytest.mark.parametrize("direction", ["LEFT", "RIGHT"]) +def test_bitwise_shift(direction: str): + shape = [32, 32] + dtype = TensorProto.UINT64 + test_node = helper.make_node("BitShift", ["a", "b"], ["c"], direction=direction) + graph = helper.make_graph( + [test_node], + "binary_test", + inputs=[ + helper.make_tensor_value_info("a", dtype, shape), + helper.make_tensor_value_info("b", dtype, shape), + ], + outputs=[helper.make_tensor_value_info("c", dtype, shape)], + ) + + model = helper.make_model(graph, producer_name="binary_test") + check_correctness(model, inputs={"b": np.random.randint(0, 8, shape).astype("uint64")}) + + @pytest.mark.parametrize( "op_name", [ diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index 85842f1578df..20c111495d6a 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -46,6 +46,8 @@ def test_op_correctness(): assert relax.op.bitwise_and(x, y).op == Op.get("relax.bitwise_and") assert relax.op.bitwise_or(x, y).op == Op.get("relax.bitwise_or") assert relax.op.bitwise_xor(x, y).op == Op.get("relax.bitwise_xor") + assert relax.op.left_shift(x, y).op == Op.get("relax.left_shift") + assert relax.op.right_shift(x, y).op == Op.get("relax.right_shift") x = relax.Var("x", R.Tensor((2, 3), "bool")) y = relax.Var("y", R.Tensor((2, 3), "bool")) From 001d5ec90c2821b16f9d4edd913dfeff03c027a3 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 8 Oct 2024 09:57:27 +0900 Subject: [PATCH 196/202] [Relax][PyTorch][Docs] Use `torch.export` insteamd of `fx.symbolic_trace` for tutorial (#17436) * use torch.export * in order to make interface consistent, user inputs should be placed first * chore --- docs/get_started/tutorials/ir_module.py | 15 ++-- docs/how_to/tutorials/e2e_opt_model.py | 18 +++-- .../torch/exported_program_translator.py | 71 ++++++++++--------- .../test_frontend_from_exported_program.py | 4 +- 4 files changed, 56 insertions(+), 52 deletions(-) diff --git a/docs/get_started/tutorials/ir_module.py b/docs/get_started/tutorials/ir_module.py index f813333bafc3..0a825c3da757 100644 --- a/docs/get_started/tutorials/ir_module.py +++ b/docs/get_started/tutorials/ir_module.py @@ -40,8 +40,9 @@ # below. import torch -from torch import fx, nn -from tvm.relax.frontend.torch import from_fx +from torch import nn +from torch.export import export +from tvm.relax.frontend.torch import from_exported_program ###################################################################### # Import from existing models @@ -67,13 +68,15 @@ def forward(self, x): return x -# Give the input shape and data type -input_info = [((1, 784), "float32")] +# Give an example argument to torch.export +example_args = (torch.randn(1, 784, dtype=torch.float32),) # Convert the model to IRModule with torch.no_grad(): - torch_fx_model = fx.symbolic_trace(TorchModel()) - mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True) + exported_program = export(TorchModel().eval(), example_args) + mod_from_torch = from_exported_program( + exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True + ) mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch) # Print the IRModule diff --git a/docs/how_to/tutorials/e2e_opt_model.py b/docs/how_to/tutorials/e2e_opt_model.py index 5c11439e1635..532fb89fd3bc 100644 --- a/docs/how_to/tutorials/e2e_opt_model.py +++ b/docs/how_to/tutorials/e2e_opt_model.py @@ -34,10 +34,10 @@ import os import numpy as np import torch -from torch import fx +from torch.export import export from torchvision.models.resnet import ResNet18_Weights, resnet18 -torch_model = resnet18(weights=ResNet18_Weights.DEFAULT) +torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval() ###################################################################### # Review Overall Flow @@ -63,21 +63,19 @@ # Convert the model to IRModule # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further -# optimization. Besides the model, we also need to provide the input shape and data type. +# optimization. import tvm from tvm import relax -from tvm.relax.frontend.torch import from_fx +from tvm.relax.frontend.torch import from_exported_program -torch_model = resnet18(weights=ResNet18_Weights.DEFAULT) - -# Give the input shape and data type -input_info = [((1, 3, 224, 224), "float32")] +# Give an example argument to torch.export +example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),) # Convert the model to IRModule with torch.no_grad(): - torch_fx_model = fx.symbolic_trace(torch_model) - mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True) + exported_program = export(torch_model, example_args) + mod = from_exported_program(exported_program, keep_params_as_input=True) mod, params = relax.frontend.detach_params(mod) mod.show() diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 1401a0bcef3a..7bcd20c462bd 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -34,37 +34,6 @@ class ExportedProgramImporter(BaseFXGraphImporter): from torch import fx - def create_input_vars( - self, exported_program: torch.export.ExportedProgram - ) -> Tuple[List[relax.Var], List[relax.Var]]: - """Create relax input vars.""" - parameters_buffers_constants = [] - user_inputs = [] - for spec in exported_program.graph_signature.input_specs: - name_hint = spec.arg.name - if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: - shape = exported_program.tensor_constants[spec.target].shape - torch_dtype = exported_program.tensor_constants[spec.target].dtype - elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target): - if node.name == name_hint: - shape = node.meta["tensor_meta"].shape - torch_dtype = node.meta["tensor_meta"].dtype - break - else: - # PARAMETER or BUFFER - shape = exported_program.state_dict[spec.target].shape - torch_dtype = exported_program.state_dict[spec.target].dtype - - dtype = self._convert_data_type(torch_dtype) - relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype)) - if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - user_inputs.append(relax_var) - else: - parameters_buffers_constants.append(relax_var) - - return parameters_buffers_constants, user_inputs - ########## Unary Ops ########## def _hardtanh(self, node: fx.Node) -> relax.Expr: @@ -178,6 +147,8 @@ def _slice(self, node: fx.Node) -> relax.Var: stride = [node.args[4] if len(node.args) > 4 else 1] return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + ########## Others ########## + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -293,6 +264,37 @@ def create_convert_map( "getitem": self._getitem, } + def create_input_vars( + self, exported_program: torch.export.ExportedProgram + ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]: + """Create relax input vars.""" + parameters_buffers_constants = OrderedDict() + user_inputs = OrderedDict() + for spec in exported_program.graph_signature.input_specs: + name_hint = spec.arg.name + if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: + shape = exported_program.tensor_constants[spec.target].shape + torch_dtype = exported_program.tensor_constants[spec.target].dtype + elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target): + if node.name == name_hint: + shape = node.meta["tensor_meta"].shape + torch_dtype = node.meta["tensor_meta"].dtype + break + else: + # PARAMETER or BUFFER + shape = exported_program.state_dict[spec.target].shape + torch_dtype = exported_program.state_dict[spec.target].dtype + + dtype = self._convert_data_type(torch_dtype) + relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype)) + if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + user_inputs[name_hint] = relax_var + else: + parameters_buffers_constants[name_hint] = relax_var + + return parameters_buffers_constants, user_inputs + def from_exported_program( self, exported_program: torch.export.ExportedProgram, @@ -305,7 +307,8 @@ def from_exported_program( # Create input variables. parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program) - inputs_vars = parameter_buffer_constant_vars + user_input_vars + inputs_vars = user_input_vars.copy() + inputs_vars.update(parameter_buffer_constant_vars) # Initialize the block builder with a function and a dataflow block. self.block_builder = relax.BlockBuilder() @@ -314,7 +317,7 @@ def from_exported_program( nodes: List[fx.Node] = exported_program.graph.nodes with self.block_builder.function( - name=func_name, params=inputs_vars.copy(), attrs=func_attrs + name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs ): output = None with self.block_builder.dataflow(): @@ -325,7 +328,7 @@ def from_exported_program( # Ignore sym input continue - self.env[node] = inputs_vars.pop(0) + self.env[node] = inputs_vars[node.name] elif node.op == "output": args = self.retrieve_args(node) assert len(args) == 1 diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 65890ff6971b..0d8425fc7f30 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3550,9 +3550,9 @@ def forward(self, input): class expected1: @R.function def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"), conv_bias: R.Tensor((6,), dtype="float32"), - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): R.func_attr({"num_input": 1}) # block 0 @@ -3586,7 +3586,7 @@ def main( params = params["main"] assert len(params) == len(func.params) - 1 - for param_var, param_ndarray in zip(func.params[:-1], params): + for param_var, param_ndarray in zip(func.params[1:], params): assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape assert param_var.struct_info.dtype == param_ndarray.dtype From eef234060d12f59fa07fff15bebcdbd6a772d594 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 9 Oct 2024 01:23:04 +0800 Subject: [PATCH 197/202] [Community] update contributors (#17450) Update recent nominations about contributors --- CONTRIBUTORS.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 35deb7def799..d9a0082e0f1f 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,6 +41,7 @@ We do encourage everyone to work anything they are interested in. - [Siyuan Feng](https://github.com/Hzfengsy) (PMC): @Hzfengsy - tir - [Josh Fromm](https://github.com/jwfromm) (PMC): @jwfromm - frontends, quantization, topi - [Mehrdad Hessar](https://github.com/mehrdadh): @mehrdadh - microTVM, hexagon +- [Masahiro Hiramori](https://github.com/mshr-h): @mshr-h - relax, frontend - [Bohan Hou](https://github.com/spectrometerHBH) (PMC): @spectrometerHBH - tir, arith, tvm-script - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - topi, frontends - [Luke Hutton](https://github.com/lhutton1): @lhutton1 - ethos-u, arm @@ -80,6 +81,7 @@ We do encourage everyone to work anything they are interested in. - [Chris Sullivan](https://github.com/csullivan): @csullivan - amd backend - [Siva Rama Krishna Reddy](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web +- [Tong Meng](https://github.com/Archermmt): @Archermmt - msc - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - topi, compiler, runtime - [Gavin Uberti](https://github.com/guberti): @guberti - microtvm, arm - [Luis Vega](https://github.com/vegaluisjose): @vegaluisjose - vta, chisel @@ -90,7 +92,7 @@ We do encourage everyone to work anything they are interested in. - [Eddie Yan](https://github.com/eqy) (PMC): @eqy - runtime, autotvm, rpc, topi - [Zihao Ye](https://github.com/yzh119): @yzh119 - tir - [Hao Yu](https://github.com/comaniac): @comaniac (PMC) - relay, byoc, auto_scheduler -- [Shuai Yuan](https://github.com/ysh329): @ysh329 - ci +- [Shuai Yuan](https://github.com/ysh329): @ysh329 (PMC) - ci - [Qiang Zhang](https://github.com/Johnson9009): @Johnson9009 - relay, tvm-script - [Lianmin Zheng](https://github.com/merrymercy) (PMC): @merrymercy - autotvm, auto_scheduler, topi, relay - [Xiyou Zhou](https://github.com/zxybazh): @zxybazh - relay @@ -123,6 +125,7 @@ We do encourage everyone to work anything they are interested in. - [Sergei Grechanik](https://github.com/sgrechanik-h): @sgrechanik-h - [Altan Haan](https://github.com/altanh): @altanh - [Mehrdad Hessar](https://github.com/mehrdadh): @mehrdadh +- [Masahiro Hiramori](https://github.com/mshr-h): @mshr-h - [Bohan Hou](https://github.com/spectrometerHBH): @spectrometerHBH - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - [Luke Hutton](https://github.com/lhutton1): @lhutton1 @@ -192,6 +195,7 @@ We do encourage everyone to work anything they are interested in. - [Chris Sullivan](https://github.com/csullivan): @csullivan - [Anirudh Sundar Subramaniam](https://github.com/quic-sanirudh): @quic-sanirudh - [Zhixun Tan](https://github.com/phisiart): @phisiart +- [Tong Meng](https://github.com/Archermmt): @Archermmt - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - [Jorn Tuyls](https://github.com/jtuyls): @jtuyls - [Gavin Uberti](https://github.com/guberti): @guberti From d50ec2367bf2124f2958e561a7ac8d39931023f7 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 10 Oct 2024 10:36:11 +0800 Subject: [PATCH 198/202] [Relax] Add NonZero op (#17453) this PR adds the NonZero op to Relax, together with ONNX frontend support --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 10 ++++- python/tvm/relax/op/__init__.py | 2 +- python/tvm/relax/op/set.py | 37 +++++++++++++++++++ src/relax/op/tensor/set.cc | 23 ++++++++++++ src/relax/op/tensor/set.h | 28 ++++++++++++++ tests/python/relax/test_frontend_onnx.py | 5 +++ tests/python/relax/test_op_set.py | 34 +++++++++++++++++ 7 files changed, 137 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index aa156a025fef..b9eb141bd14e 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2482,6 +2482,14 @@ def _impl_v11(cls, bb, inputs, attr, params): return relax.op.unique(data, sorted=sorted, axis=axis) +class NonZero(OnnxOpConverter): + """Converts an onnx NonZero node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.nonzero(inputs[0]) + + class HardSigmoid(OnnxOpConverter): """Converts an onnx HardSigmoid node into an equivalent Relax expression.""" @@ -2867,7 +2875,7 @@ def _get_convert_map(): "Range": Range, "OneHot": OneHot, "Unique": Unique, - # "NonZero": NonZero, + "NonZero": NonZero, # "If": If, # "LRN": LRN, # "MaxRoiPool": MaxRoiPool, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index c99201e969b5..efd9997698ee 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -101,7 +101,7 @@ from .qdq import dequantize, quantize from .sampling import multinomial_from_uniform from .search import argmax, argmin, where -from .set import unique +from .set import nonzero, unique from .sorting import argsort, sort, topk from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance from .ternary import ewise_fma diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index 0b86e19ce53f..c5db852ddd5d 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -110,3 +110,40 @@ def numpy_unique( return tvm.nd.array(output_sorted_numpy) output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis) return tvm.nd.array(output_numpy) + + +def nonzero(x: Expr) -> Expr: + """Find the indices of elements of a tensor that are non-zero. + + Parameters + ---------- + x : relax.Expr + The input data tensor. + + Returns + ------- + result : relax.Expr + A (n+1)-D tensor containing indices of non-zero elements. + + Note + ---- + This function is equivalent to `onnx.nonzero`. + + Examples + -------- + + .. code-block:: python + + x = [[0, 1], + [2, 0]] + nonzero(x) = [[0, 1], + [1, 0]] + + """ + return _ffi_api.nonzero(x) # type: ignore + + +@tvm.register_func("relax.run.nonzero") +def numpy_nonzero(x: tvm.nd.array) -> tvm.nd.array: + np_result = np.atleast_1d(x.numpy()).nonzero() + return tvm.nd.array(np.stack(np_result, axis=0)) diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 29d9d52c6077..c659a49afd12 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -24,6 +24,7 @@ #include "set.h" +#include #include #include @@ -137,5 +138,27 @@ TVM_REGISTER_OP("relax.unique") .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", Bool(true)); +/* relax.nonzero */ +Expr nonzero(Expr x) { + static const Op& op = Op::Get("relax.nonzero"); + return Call(op, {std::move(x)}); +} + +TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); + +StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + // Cheat zero dim scalar as 1-dim. + int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1; + return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.nonzero") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoNonzero) + .set_attr("FCallPacked", "relax.run.nonzero") + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h index a5c7ee85bfb2..251dd1975e9f 100644 --- a/src/relax/op/tensor/set.h +++ b/src/relax/op/tensor/set.h @@ -29,8 +29,36 @@ namespace tvm { namespace relax { +/*! + * \brief Find the unique elements in a given tensor. + * In addition, it optionally returns + * - the indices of the input tensor that give the unique values; + * - the indices of the unique tensor that reconstruct the input tensor; + * - the number of times each unique value comes up in the input tensor. + * \param x The input tensor. + * \param sorted Whether to sort the unique elements in ascending order before + * returning as output. + * \param return_index Whether to return an additional tensor with indices for where elements in + * the unique tensor come from the original input. + * \param return_inverse Whether to return an additional tensor with indices for where elements in + * the original input ended up in the returned unique list. + * \param return_counts Whether to return an additional tensor with counts of each unique elements. + * \param axis The dimension to apply unique. + * If not specified, the unique values of the flattened input are returned. + * \return The unique elements of the array. The returned array will be sorted if `sorted` is True. + * Additional return values depend on `return_index`, `return_inverse`, and `return_counts`. + */ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, PrimValue return_counts, Optional axis); + +/*! + * \brief Returns the indices of the non-zero elements of the input tensor. + * \param x The input tensor. + * \return a list of 1-D tensors containing indices of non-zero elements for each dimension. + * \note This function behaves similarly to numpy.nonzero(), but return a multi-dimensional array + * instead of a tuple of 1-D arrays. + */ +Expr nonzero(Expr x); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index e3ed3a3a9d4d..57f94c8442f7 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2162,6 +2162,11 @@ def test_unique(axis: Optional[int], sorted: int): check_correctness(model) +@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)]) +def test_nonzero(shape): + verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64) + + @pytest.mark.parametrize("mode", ["DCR", "CRD"]) def test_depth_to_space(mode: Literal["DCR", "CRD"]): in_shape = [1, 8, 2, 3] diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py index 741d7869d52f..e9070f99fc3f 100644 --- a/tests/python/relax/test_op_set.py +++ b/tests/python/relax/test_op_set.py @@ -867,5 +867,39 @@ def test_unique_infer_struct_info_wrong_input_dtype(): bb.normalize(relax.op.unique(x1)) +@pytest.mark.parametrize("shape", [(1,), (2, 3), (4, 5, 6)]) +def test_nonzero_infer_struct_info(shape): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor(shape, "bool")) + + _check_inference( + bb, + relax.op.nonzero(x0), + relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"), + ) + + +def test_nonzero_infer_struct_info_ndim_zero(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((), "bool")) + + _check_inference( + bb, + relax.op.nonzero(x), + relax.TensorStructInfo(ndim=2, dtype="int64"), + ) + + +def test_nonzero_infer_struct_info_wrong_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nonzero(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nonzero(x1)) + + if __name__ == "__main__": tvm.testing.main() From 910ee0e852e32dd9a6e7c495229aa37847a7e473 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 10 Oct 2024 10:36:30 +0800 Subject: [PATCH 199/202] [Relax] Add scatter_nd op support (#17449) Add relax scatter_nd op support and ONNX frontend support. --- include/tvm/relax/attrs/manipulate.h | 12 ++ .../tvm/relax/frontend/onnx/onnx_frontend.py | 32 ++++- python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/manipulate.py | 39 +++++ .../transform/legalize_ops/manipulate.py | 17 +++ python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/tensor/manipulate.cc | 134 ++++++++++++++++++ src/relax/op/tensor/manipulate.h | 33 +++++ tests/python/relax/test_frontend_onnx.py | 33 ++++- tests/python/relax/test_op_manipulate.py | 25 ++++ .../test_transform_legalize_ops_manipulate.py | 62 +++++++- 11 files changed, 387 insertions(+), 3 deletions(-) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index ef4265d73b4b..e53ba3c36e7f 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -164,6 +164,18 @@ struct ScatterElementsAttrs : public tvm::AttrsNode { "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\"."); } }; // struct ScatterElementsAttrs + +/*! \brief Attributes used in scatter_nd operators */ +struct ScatterNDAttrs : public tvm::AttrsNode { + String reduction; + + TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") { + TVM_ATTR_FIELD(reduction).set_default("update").describe( + "Accumulation mode of the ScatterND, " + "either \"update\", \"add\", \"mul\", \"min\" or \"max\"."); + } +}; // struct ScatterNDAttrs + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b9eb141bd14e..f1fa67546c2a 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -692,6 +692,36 @@ def _impl_v11(cls, bb, inputs, attr, params): return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis) +class ScatterND(OnnxOpConverter): + """Convert an onnx ScatterND node into an equivalent Relax expression.""" + + @staticmethod + def _reduction_check(attr, valid_reductions: List[str]): + reduction = attr.get("reduction", None) + reduction = reduction or b"update" + reduction = reduction.decode("utf-8") + reduction = "update" if reduction == "none" else reduction + assert ( + reduction in valid_reductions + ), f"Only {valid_reductions} reductions are supported, but {reduction} is gotten" + + return reduction + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2]) + + @classmethod + def _impl_v16(cls, bb, inputs, attr, params): + reduction = cls._reduction_check(attr, ["update", "add", "mul"]) + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction) + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"]) + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction) + + class Size(OnnxOpConverter): """Convert an onnx Size node into an equivalent Relax expression.""" @@ -2827,7 +2857,7 @@ def _get_convert_map(): # "GatherND": GatherND, "Scatter": Scatter, "ScatterElements": ScatterElements, - # "ScatterND": ScatterND, + "ScatterND": ScatterND, # "Compress": Compress, "Size": Size, # "EyeLike": EyeLike, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index efd9997698ee..84b31ccec01e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -93,6 +93,7 @@ repeat, reshape, scatter_elements, + scatter_nd, split, squeeze, tile, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index da0a09cc7b51..1673a79b08c2 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -511,3 +511,42 @@ def scatter_elements( """ return _ffi_api.scatter_elements(data, indices, updates, axis, reduction) # type: ignore + + +def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "update") -> Expr: + """Scatter updates into an array according to indices. + + Parameters + ---------- + data: relax.Expr + The input data to be updated. + + indices: relax.Expr + The index positions to update in `data`. + + updates: relax.Expr + Values to replace to. + + reduction: str + Type of reduction to apply: update, add, mul, max, min. + It is "update" by default. + + Returns + ------- + result : relax.Expr + The result has the same shape as data. + + Examples + -------- + .. code-block:: python + + # inputs + data = [1, 2, 3, 4, 5, 6, 7, 8] + indices = [[4], [3], [1], [7]] + updates = [9, 10, 11, 12] + + # output + output = [1, 11, 3, 10, 9, 6, 7, 12] + + """ + return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 1efa78c069ad..105d763403af 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -168,6 +168,23 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.scatter_nd") +def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr: + # TODO(relax-team): Support native scatter_nd without te extern + def scatter_nd(data, indices, updates, reduction): + axes = list(range(len(indices.shape))) + indices = topi.transpose(indices, axes[-1:] + axes[:-1]) + return topi.scatter_nd(data, indices, updates, reduction) + + return bb.call_te( + scatter_nd, + call.args[0], + call.args[1], + call.args[2], + call.attrs.reduction, + ) + + @register_legalize("relax.layout_transform") def _layout_transform(bb: BlockBuilder, call: Call) -> Expr: def te_layout_transform(data, name): diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e6ff35ebe56b..f7847e2af8ed 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -138,6 +138,7 @@ round, rsqrt, scatter_elements, + scatter_nd, shape_of, shape_to_tensor, sigmoid, @@ -738,6 +739,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "cumsum", "einsum", "scatter_elements", + "scatter_nd", "dataflow", "device", "divide", diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2b1c6eafb652..ca7d0a0945bc 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -1531,5 +1531,139 @@ TVM_REGISTER_OP("relax.scatter_elements") .set_attr("FInferStructInfo", InferStructInfoScatterElements) .set_attr("FPurity", Bool(true)); +/* relax.scatter_nd */ +TVM_REGISTER_NODE_TYPE(ScatterNDAttrs); + +Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) { + auto attrs = make_object(); + attrs->reduction = std::move(reduction); + static const Op& op = Op::Get("relax.scatter_nd"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd); + +StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { + // `call->args` contains: [data, indices, updates] + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + ICHECK_EQ(call->args.size(), 3); + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* indices_sinfo = GetStructInfoAs(call->args[1]); + const auto* updates_sinfo = GetStructInfoAs(call->args[2]); + + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input data to be a tensor. However, the given type is " + << call->args[0]->GetTypeKey()); + } + if (indices_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input indices to be a tensor. However, the given type is " + << call->args[1]->GetTypeKey()); + } + if (updates_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input updates to be a tensor. However, the given type is " + << call->args[2]->GetTypeKey()); + } + + if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input data and updates to have known dtype. " + "However, the given types are " + << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype); + } + + if (data_sinfo->dtype != updates_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input data to have same type with updates. " + "However, the given types are " + << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype); + } + + if (indices_sinfo->IsUnknownDtype()) { + LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; + } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input indices to have integer dtype. However, " + "the given indices dtype is " + << indices_sinfo->dtype); + } + + const auto* data_shape = data_sinfo->shape.as(); + const auto* indices_shape = indices_sinfo->shape.as(); + const auto* updates_shape = updates_sinfo->shape.as(); + + if (data_shape && indices_shape && updates_shape) { + const IntImmNode* k_dim = indices_shape->values[indices_sinfo->ndim - 1].as(); + if (!k_dim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND needs a static shape for the last axis of indices, got " + << indices_shape->values); + } + const size_t data_ndim = data_sinfo->ndim; + const size_t indices_ndim = indices_sinfo->ndim; + const size_t updates_ndim = updates_sinfo->ndim; + if (data_ndim + indices_ndim - k_dim->value - 1 != updates_ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the updates tensor to have the rank of " + "`data tensor + indices tensor - last axis of indices tensor - 1`. " + "However, the given shapes are " + << "data: " << ShapeExpr(data_shape->values) + << ", indices: " << ShapeExpr(indices_shape->values) + << ", updates: " << ShapeExpr(updates_shape->values)); + } + if (k_dim->value > static_cast(data_ndim)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the last axis of indices tensor to be less than " + "or equal to the rank of data tensor. However, the given shapes are " + << "data: " << ShapeExpr(data_shape->values) + << ", indices: " << ShapeExpr(indices_shape->values)); + } + Array expected_updates_shape; + for (size_t i = 0; i < indices_ndim - 1; i++) { + expected_updates_shape.push_back(indices_shape->values[i]); + } + for (size_t i = k_dim->value; i < data_ndim; i++) { + expected_updates_shape.push_back(data_shape->values[i]); + } + auto check_shape = [&](const Array& expected, const Array& actual) { + if (expected.size() != actual.size()) { + return false; + } + for (size_t i = 0; i < expected.size(); i++) { + if (!analyzer->CanProve(expected[i] == actual[i])) { + return false; + } + } + return true; + }; + if (!check_shape(expected_updates_shape, updates_shape->values)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the updates tensor to have the shape with constraint: " + << "`updates.shape = indices.shape[:-1] + data.shape[K:]`, but got " + << "updates.shape: " << ShapeExpr(updates_shape->values) << ", indices.shape: " + << ShapeExpr(indices_shape->values) << ", data.shape: " << ShapeExpr(data_shape->values)); + } + } + if (data_shape) { + return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice); + } + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.scatter_nd") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("updates", "Tensor", "The input tensor of updates.") + .set_attr("FInferStructInfo", InferStructInfoScatterND) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 68622f1359e0..e9fa1131e803 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -173,6 +173,39 @@ Expr tile(Expr data, Array repeats); */ Expr flip(Expr data, Integer axis); +/*! + * \brief Scatter updates into an array according to indices. + * \param data The input tensor. + * \param indices The index positions to update in `data`. + * \param updates The values to replace to. + * \param axis The axis along which to scatter the elements. + * \param reduction The reduction mode of the scatter elements, + * either "update", "add", "mul", "mean", "max" or "min". + * \return The computed result. + */ +Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction); + +/*! + * \brief Scatter updates into an array according to indices. + * \param data The input tensor to be updated. + * \param indices The index positions to update in `data`. + * \param updates The values to replace to. + * \param reduction The reduction mode of the scatter operation. + * Supported modes are: + * - "update": Replace the values at the indices with the update values. + * - "add": Add the update values to the existing values at the indices. + * - "mul": Multiply the existing values at the indices by the update values. + * - "max": Take the maximum of the existing value and the update value at each index. + * - "min": Take the minimum of the existing value and the update value at each index. + * \return The computed result tensor with the same shape as `data`. + * + * \note The shape of `indices` defines the shape of the scattered tensor. + * The last dimension of `indices` corresponds to the depth of each index vector. + * The shape of `updates` must match the shape of `indices` except for the last dimension, + * which must match the slice shape at each index. + */ +Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 57f94c8442f7..9ac520c58e14 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -118,7 +118,6 @@ def check_correctness( tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) # Legalize any relax ops into tensorir. tvm_model = relax.transform.LegalizeOps()(tvm_model) - print(tvm_model) # Separate model from parameters. tvm_model, params = relax.frontend.detach_params(tvm_model) @@ -523,6 +522,38 @@ def test_scatter(axis: int, name: str, opset: int): check_correctness(model, inputs={"indices": indices}, opset=opset) +@pytest.mark.parametrize("reduction", ["none", "add", "mul"]) +def test_scatter_nd(reduction): + def verify_scatter_nd(data_shape, indices_shape, updates_shape): + scatter_nd_node = helper.make_node( + "ScatterND", + ["data", "indices", "updates"], + ["output"], + reduction=reduction, + ) + + graph = helper.make_graph( + [scatter_nd_node], + "scatter_nd_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)], + ) + + model = helper.make_model(graph, producer_name="scatter_nd_test") + + indices = np.random.choice(data_shape[0], indices_shape) + check_correctness(model, inputs={"indices": indices}, opset=16) + + verify_scatter_nd([8], [4, 1], [4]) + verify_scatter_nd([4, 4, 4], [2, 1], [2, 4, 4]) + verify_scatter_nd([4, 5, 6], [2, 3, 2], [2, 3, 6]) + verify_scatter_nd([10], [5, 1], [5]) + + def test_size(): test_node = helper.make_node("Size", ["x"], ["y"]) graph = helper.make_graph( diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index ddb92725d438..e958b03e4ce6 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -45,6 +45,7 @@ def test_op_correctness(): assert relax.op.einsum(x, subscripts="ii").op == Op.get("relax.einsum") assert relax.op.flip(x, axis=1).op == Op.get("relax.flip") assert relax.op.scatter_elements(x, x, x).op == Op.get("relax.scatter_elements") + assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -3352,5 +3353,29 @@ def test_scatter_elements_infer_struct_info_rank_shape_mismatch(): bb.normalize(relax.op.scatter_elements(d0, i0, u4)) +def test_scatter_nd_infer_struct_info(): + bb = relax.BlockBuilder() + + d0 = relax.Var("data", R.Tensor((8,), "float32")) + i0 = relax.Var("indices", R.Tensor((4, 1), "int64")) + u0 = relax.Var("updates", R.Tensor((4,), "float32")) + + _check_inference( + bb, + relax.op.scatter_nd(d0, i0, u0, "update"), + relax.TensorStructInfo((8,), dtype="float32"), + ) + + d1 = relax.Var("data", R.Tensor((4, 4, 4), "float32")) + i1 = relax.Var("indices", R.Tensor((2, 1), "int64")) + u1 = relax.Var("updates", R.Tensor((2, 4, 4), "float32")) + + _check_inference( + bb, + relax.op.scatter_nd(d1, i1, u1, "update"), + relax.TensorStructInfo((4, 4, 4), dtype="float32"), + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index a0ecd3c73dc9..0565b7a5790a 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import pytest import tvm from tvm import relax from tvm.relax.transform import LegalizeOps @@ -1739,5 +1738,66 @@ def te_layout_transform( tvm.ir.assert_structural_equal(Expected, After) +def test_scatter_nd(): + + # fmt: off + @I.ir_module + class Before: + @R.function + def main( + data: R.Tensor((8,), "float32"), + indices: R.Tensor((4, 1), "int64"), + updates: R.Tensor((4,), "float32"), + ) -> R.Tensor((8,), "float32"): + gv: R.Tensor((8,), "float32") = R.scatter_nd(data, indices, updates, reduction="update") + return gv + + After = relax.transform.LegalizeOps()(Before) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((8,), "float32"), + indices: R.Tensor((4, 1), "int64"), + updates: R.Tensor((4,), "float32"), + ) -> R.Tensor((8,), "float32"): + gv = R.call_tir( + Expected.scatter_nd, (data, indices, updates), R.Tensor((8,), dtype="float32") + ) + return gv + + @T.prim_func(private=True) + def scatter_nd(var_data: T.handle, var_indices: T.handle, var_updates: T.handle, var_scatter_nd_generic: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + data = T.match_buffer(var_data, (T.int64(8),), offset_factor=1) + indices = T.match_buffer(var_indices, (T.int64(4), T.int64(1)), "int64") + updates = T.match_buffer(var_updates, (T.int64(4),), offset_factor=1) + out_buf = T.match_buffer(var_scatter_nd_generic, (T.int64(8),)) + with T.block("root"): + T.reads() + T.writes() + T_transpose = T.alloc_buffer((T.int64(1), T.int64(4)), "int64") + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(4)): + with T.block("T_transpose"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(4), ax1) + T.reads(indices[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = indices[v_ax1, v_ax0] + with T.block("scatter_nd_generic"): + T.reads() + T.writes() + for i in range(T.int64(8)): + out_buf[i] = data[i] + for j in range(T.int64(4)): + for k in T.parallel(T.int64(1)): + out_buf[k + T_transpose[j // T.int64(4), j % T.int64(4)]] = updates[j + k] + + # fmt: on + tvm.ir.assert_structural_equal(After, Expected) + + if __name__ == "__main__": tvm.testing.main() From 74ed86b5df128dffeedac1eb6bbd345b1a756327 Mon Sep 17 00:00:00 2001 From: Honglin Zhu Date: Thu, 10 Oct 2024 10:37:02 +0800 Subject: [PATCH 200/202] [Relax][Frontend][Onnx] Add support for pad-2 (#17431) * fix params name bug * add support for onnx pad_v2 * Update test_frontend_onnx.py * Update onnx_frontend.py --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 29 ++++++++++ tests/python/relax/test_frontend_onnx.py | 57 +++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index f1fa67546c2a..4770b7ce5cc5 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1582,6 +1582,35 @@ def _impl_v13(cls, bb, inputs, attr, params): class Pad(OnnxOpConverter): """Converts an onnx Pad node into an equivalent Relax expression.""" + @classmethod + def _impl_v2(cls, bb, inputs, attr, params): + pads = attr.get("pads") + pads = relax.const(_np.array(pads), inputs[0].struct_info.shape[0].dtype) + constant_value = attr.get("value") + if constant_value is None: + constant_value = 0.0 + + if isinstance(pads, relax.Constant): + pad_before, pad_after = _np.split(pads.data.numpy(), 2) + pad_before = _np.ndarray.tolist(pad_before) + pad_after = _np.ndarray.tolist(pad_after) + else: + raise ValueError("Dynamic pads are not supported yet.") + + pad_mode = attr.get("mode", b"constant").decode("utf-8") + if not pad_mode in ["constant", "edge", "reflect"]: + raise tvm.error.OpAttributeInvalid( + "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.' + ) + + if pad_mode == "constant": + return bb.emit_te(topi.nn.pad, inputs[0], pad_before, pad_after, constant_value) + elif pad_mode == "reflect": + return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT") + else: + # TODO(gigiblender) Support edge mode. + raise NotImplementedError("Pad mode {} not implemented".format(pad_mode)) + @classmethod def _impl_v11(cls, bb, inputs, attr, params): pads = get_constant(inputs[1], params) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 9ac520c58e14..1b4c5d281abb 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1696,6 +1696,63 @@ def verify_pad(input_shape, pads, mode="constant", value=0.0): verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect") +@pytest.mark.parametrize("dynamic", [True, False]) +def test_pad_v2(dynamic): + + if dynamic: + pytest.skip("Dynamic pad not supported") + + def verify_pad(input_shape, pads, mode="constant", value=0.0): + indata = np.random.normal(size=input_shape).astype(np.float32) + # numpy expect result + len_dim = len(pads) // 2 + np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] + pads = np.array(pads) + # onnx graph + if mode in ["edge", "reflect"]: + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + "Pad", inputs=["input"], outputs=["output"], mode=mode, pads=pads + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)) + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + else: + outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) + node = helper.make_node( + "Pad", + inputs=["input"], + outputs=["output"], + mode="constant", + pads=pads, + value=value, + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)) + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + model = helper.make_model(graph, producer_name="pad_test") + check_correctness(model=model, opset=10) + + verify_pad((2, 2), [0, 1, 0, 0], "constant", 0.0) + verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0) + verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0) + verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect") + + @pytest.mark.parametrize("fp_arith", [np.float16, np.float32]) @pytest.mark.parametrize("dynamic", [True, False]) def test_split(fp_arith, dynamic): From 7d2fa11bd16972368bfbaab0a872541fa76745a7 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 10 Oct 2024 23:02:51 +0800 Subject: [PATCH 201/202] Try to fix windows CI conda build issue (#17457) try fix ci --- conda/build-environment.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index 8eb25ce01ac7..de4e6f4234d7 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -26,7 +26,8 @@ channels: # The packages to install to the environment dependencies: - python=3.9 - - conda-build + - conda < 24.9.0 + - conda-build < 24.9.0 - git - llvmdev >=11 - numpy From 22a9d388d441dbfd917d032564e2a1bccacd5f8c Mon Sep 17 00:00:00 2001 From: ysh329 Date: Fri, 11 Oct 2024 09:17:59 +0000 Subject: [PATCH 202/202] [release] Update version to 0.18.0 on main branch --- conda/recipe/meta.yaml | 2 +- include/tvm/runtime/c_runtime_api.h | 2 +- python/tvm/_ffi/libinfo.py | 2 +- version.py | 2 +- web/package-lock.json | 4 ++-- web/package.json | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/conda/recipe/meta.yaml b/conda/recipe/meta.yaml index d4477468c79d..c5e3840ff613 100644 --- a/conda/recipe/meta.yaml +++ b/conda/recipe/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = '0.18.dev0' %} +{% set version = '0.18.0' %} {% set pkg_name = 'tvm' %} {% set cuda_tag = cuda_version | replace('.', '') %} # [cuda] {% set pkg_name = pkg_name + '-cu' + cuda_tag %} # [cuda] diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index d26c95e4f53c..8071020cef28 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -73,7 +73,7 @@ #endif // TVM version -#define TVM_VERSION "0.18.dev0" +#define TVM_VERSION "0.18.0" // TVM Runtime is DLPack compatible. #include diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 2ec4ba8e31be..6e39d5b33a99 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -247,4 +247,4 @@ def find_include_path(name=None, search_path=None, optional=False): # We use the version of the incoming release for code # that is under development. # The following line is set by tvm/python/update_version.py -__version__ = "0.18.dev0" +__version__ = "0.18.0" diff --git a/version.py b/version.py index a827571c6cdf..cea1ba306c57 100644 --- a/version.py +++ b/version.py @@ -44,7 +44,7 @@ # Two tag formats are supported: # - vMAJ.MIN.PATCH (e.g. v0.8.0) or # - vMAJ.MIN.devN (e.g. v0.8.dev0) -__version__ = "0.18.dev0" +__version__ = "0.18.0" # --------------------------------------------------- diff --git a/web/package-lock.json b/web/package-lock.json index 751aaf2ef442..6c7e024f2236 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "tvmjs", - "version": "0.18.0-dev2", + "version": "0.18.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.18.0-dev2", + "version": "0.18.0", "license": "Apache-2.0", "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", diff --git a/web/package.json b/web/package.json index a63997bb2f1c..c8d33be0b5e9 100644 --- a/web/package.json +++ b/web/package.json @@ -3,7 +3,7 @@ "description": "TVM WASM/WebGPU runtime for JS/TS", "license": "Apache-2.0", "homepage": "https://github.com/apache/tvm/tree/main/web", - "version": "0.18.0-dev2", + "version": "0.18.0", "files": [ "lib" ],