Skip to content

Commit

Permalink
Add debug dump
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Mar 12, 2023
1 parent eb6f8a4 commit 5bfd964
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@ def _parse_args():
args = argparse.ArgumentParser()
args.add_argument("--target", type=str, default="apple/m2-gpu")
args.add_argument("--db-path", type=str, default="log_db/")
args.add_argument("--from-checkpt", type=str, choices=["deploy"], default="")
args.add_argument("--artifact-path", type=str, default="dist")
args.add_argument(
"--use-cache",
type=int,
default=1,
help="Whether to use previously pickled IRModule and skip trace.",
)
args.add_argument("--debug-dump", action="store_true", default=False)

args.add_argument("--show-build-stage", action="store_true", default=False)
parsed = args.parse_args()

if parsed.target == "webgpu":
Expand Down Expand Up @@ -68,8 +67,9 @@ def legalize_and_lift_params(
mod = relax.pipeline.get_pipeline()(mod)
mod = relax.transform.RemoveUnusedFunctions(entry_funcs)(mod)
mod = relax.transform.LiftTransformParams()(mod)
if args.show_build_stage:
mod.show()

debug_dump(mod_deploy, "mod_lift_params.py", args)

mod_transform, mod_deploy = utils.split_transform_deploy_mod(
mod, model_names, entry_funcs
)
Expand All @@ -80,12 +80,25 @@ def legalize_and_lift_params(
return mod_deploy


def debug_dump(mod, name, args):
"""Debug dump mode"""
if not args.debug_dump:
return
dump_path = os.path.join(
args.artifact_path, "debug", name)
with open(dump_path, "w") as outfile:
outfile.write(mod.script(show_meta=True))
print(f"Dump mod to {dump_path}")

def build(mod: tvm.IRModule, args: Dict) -> None:
from tvm import meta_schedule as ms

db = ms.database.create(work_dir=args.db_path)
with args.target, db, tvm.transform.PassContext(opt_level=3):
mod_deploy = relax.transform.MetaScheduleApplyDatabase()(mod)

debug_dump(mod_deploy, "mod_build_stage.py", args)

ex = relax.build(mod_deploy, args.target)

target_kind = args.target.kind.default_keys[0]
Expand All @@ -101,6 +114,7 @@ def build(mod: tvm.IRModule, args: Dict) -> None:
if __name__ == "__main__":
ARGS = _parse_args()
os.makedirs(ARGS.artifact_path, exist_ok=True)
os.makedirs(os.path.join(ARGS.artifact_path, "debug"), exist_ok=True)
torch_dev_key = utils.detect_available_torch_device()
cache_path = os.path.join(ARGS.artifact_path, "mod_cache_before_build.pkl")
use_cache = ARGS.use_cache and os.path.isfile(cache_path)
Expand Down

0 comments on commit 5bfd964

Please sign in to comment.