Skip to content

Commit

Permalink
[SLM] Enable Debug Dump (mlc-ai#1499)
Browse files Browse the repository at this point in the history
This PR enables the debug dump feature. The command would be something
like

```
mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json --device cuda -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so --debug-dump debug/
```

And it would dump 6 files in the `debug/` folder:

```
debug-phase0.py
debug-phase1.py
debug-phase2.py
debug-phase3.py
debug-phase4.py
debug-final.py
```
  • Loading branch information
Hzfengsy authored Dec 28, 2023
1 parent 779b1a5 commit a9f1b72
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
tmp/
dist/
params/
debug/
*.bak
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
18 changes: 17 additions & 1 deletion python/mlc_chat/cli/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import json
import re
from functools import partial
from pathlib import Path
from typing import Union

Expand Down Expand Up @@ -37,6 +38,14 @@ def _parse_output(path: Union[str, Path]) -> Path:
raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}")
return path

def _parse_dir(path: Union[str, Path], auto_create: bool = False) -> Path:
path = Path(path)
if not auto_create and not path.is_dir():
raise argparse.ArgumentTypeError(f"Directory does not exist: {path}")
if auto_create and not path.is_dir():
path.mkdir(parents=True)
return path

def _check_system_lib_prefix(prefix: str) -> str:
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
if prefix == "" or re.match(pattern, prefix):
Expand All @@ -46,7 +55,7 @@ def _check_system_lib_prefix(prefix: str) -> str:
"numbers (0-9), alphabets (A-Z, a-z) and underscore (_)."
)

parser = ArgumentParser("MLC LLM Compiler")
parser = ArgumentParser("mlc_chat compile")
parser.add_argument(
"model",
type=detect_mlc_chat_config,
Expand Down Expand Up @@ -103,6 +112,12 @@ def _check_system_lib_prefix(prefix: str) -> str:
default="",
help=HELP["overrides"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--debug-dump",
type=partial(_parse_dir, auto_create=True),
default=None,
help=HELP["debug_dump"] + " (default: %(default)s)",
)
parsed = parser.parse_args(argv)
target, build_func = detect_target_and_host(parsed.device, parsed.host)
parsed.model_type = detect_model_type(parsed.model_type, parsed.model)
Expand All @@ -123,4 +138,5 @@ def _check_system_lib_prefix(prefix: str) -> str:
system_lib_prefix=parsed.system_lib_prefix,
output=parsed.output,
overrides=parsed.overrides,
debug_dump=parsed.debug_dump,
)
32 changes: 30 additions & 2 deletions python/mlc_chat/compiler_pass/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The compilation pipeline for LLM applications."""
from typing import Any, Dict, List
from pathlib import Path
from typing import Any, Dict, List, Optional

import tvm
from tvm import IRModule
Expand Down Expand Up @@ -34,13 +35,34 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
return mod


@tvm.transform.module_pass(opt_level=0, name="DebugDump")
class _DebugDump: # pylint: disable=too-few-public-methods
"""A dummy compiler pass that does nothing but logging.
Only enabled when debug_dump is not None"""

def __init__(self, file_name: str, file_path: Optional[Path], show_meta: bool = False):
self.file_name = file_name
self.file_path = file_path
self.show_meta = show_meta

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""A dummy transformation that dumps the module to file"""
if self.file_path is not None:
# NOTE: We use debug level here to avoid spamming the console
logger.debug("Dumping IR to %s", self.file_path / self.file_name)
with open(self.file_path / self.file_name, "w", encoding="utf-8") as f:
f.write(mod.script(show_meta=self.show_meta))
return mod


@register_pipeline("mlc_llm")
def _mlc_llm_pipeline(
def _mlc_llm_pipeline( # pylint: disable=too-many-arguments
variable_bounds: Dict[str, int] = None,
additional_tirs: Dict[str, tvm.tir.PrimFunc] = None,
metadata: Dict[str, Any] = None,
ext_mods: List[nn.ExternModule] = None,
skip_gemm: bool = False,
debug_dump: Optional[Path] = None,
):
variable_bounds = variable_bounds or {}
additional_tirs = additional_tirs or {}
Expand All @@ -54,23 +76,27 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
# Phase 0. Add additional information for compilation
AttachVariableBounds(variable_bounds),
AttachAdditionalPrimFuncs(additional_tirs),
_DebugDump("debug-phase0.py", debug_dump, show_meta=False),
# Phase 1. Passes on high-level operator graph
_LogProgress("Running TVM Relax graph-level optimizations"),
FuseDequantizeTranspose(skip_gemm=skip_gemm),
FuseTransposeMatmul(),
_DebugDump("debug-phase1.py", debug_dump, show_meta=False),
# Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
_LogProgress("Lowering to TVM TIR kernels"),
tvm.relax.transform.LegalizeOps(),
tvm.relax.transform.AnnotateTIROpPattern(),
tvm.relax.transform.FoldConstant(),
tvm.relax.transform.FuseOps(),
tvm.relax.transform.FuseTIR(),
_DebugDump("debug-phase2.py", debug_dump, show_meta=False),
# Phase 3. Passes on TIR
_LogProgress("Running TVM TIR-level optimizations"),
FuseDequantizeMatmulEwise(),
FuseDequantizeTake(),
tvm.relax.transform.DeadCodeElimination(),
CleanUpTIRAttrs(["op_pattern"]),
_DebugDump("debug-phase3.py", debug_dump, show_meta=False),
# Phase 4. Low-level Optimizations
_LogProgress("Running TVM Dlight low-level optimizations"),
dl.ApplyDefaultSchedule(
Expand All @@ -80,6 +106,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
),
_DebugDump("debug-phase4.py", debug_dump, show_meta=False),
_LogProgress("Lowering to VM bytecode"),
LiftTIRGlobalBufferAlloc(),
tvm.tir.transform.ForceNarrowIndexToInt32(),
Expand All @@ -95,6 +122,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
tvm.relax.transform.VMBuiltinLower(),
tvm.relax.transform.VMShapeLower(),
tvm.relax.transform.AttachGlobalSymbol(),
_DebugDump("debug-final.py", debug_dump, show_meta=False),
_LogProgress("Compiling external modules"),
tvm.relax.transform.AttachExternModules(ext_mods),
_LogProgress("Compilation complete! Exporting to disk"),
Expand Down
7 changes: 6 additions & 1 deletion python/mlc_chat/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"context_window_size": """
Option to provide the maximum sequence length supported by the model.
This is usually explicitly shown as context length or context window in the model card.
If this option is not set explicitly, by default,
If this option is not set explicitly, by default,
it will be determined by `context_window_size` or `max_position_embeddings` in `config.json`,
and the latter is usually inaccurate for some models.
""".strip(),
Expand Down Expand Up @@ -110,5 +110,10 @@
`context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`,
`max_batch_size` and `tensor_parallel_shards`. Meanwhile, model config could be explicitly
specified via details knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128".
""".strip(),
"debug_dump": """
Specifies the directory where the compiler will store its IRs for debugging purposes
during various phases of compilation. By default, this is set to `None`, indicating
that debug dumping is disabled.
""".strip(),
}
8 changes: 7 additions & 1 deletion python/mlc_chat/interface/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import dataclasses
from io import StringIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

from tvm import IRModule, relax, tir
from tvm.ir.transform import Pass
Expand Down Expand Up @@ -97,6 +97,7 @@ class CompileArgs: # pylint: disable=too-many-instance-attributes
system_lib_prefix: str
output: Path
overrides: ModelConfigOverride
debug_dump: Optional[Path]

def __post_init__(self) -> None:
self.opt.update(self.target)
Expand All @@ -113,6 +114,8 @@ def display(self) -> None:
print(f" {bold('--system-lib-prefix'):<25} \"{self.system_lib_prefix}\"", file=out)
print(f" {bold('--output'):<25} {self.output}", file=out)
print(f" {bold('--overrides'):<25} {self.overrides}", file=out)
# As it's debug only, no need to display
# print(f" {bold('--debug-dump'):<25} {self.debug_dump}", file=out)
print(out.getvalue().rstrip())


Expand Down Expand Up @@ -200,6 +203,7 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
additional_tirs=additional_tirs,
ext_mods=ext_mods,
metadata=metadata,
debug_dump=args.debug_dump,
),
)
logger.info("Generated: %s", bold(str(args.output)))
Expand All @@ -215,6 +219,7 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin
system_lib_prefix: str,
output: Path,
overrides: ModelConfigOverride,
debug_dump: Optional[Path] = None,
):
"""Compile a model given its configuration and quantization format to a specific target."""
if "model_config" in config:
Expand All @@ -231,6 +236,7 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin
system_lib_prefix,
output,
overrides,
debug_dump,
)
args.display()
_compile(args, model_config)
Expand Down

0 comments on commit a9f1b72

Please sign in to comment.