From 779b1a5819bf6c517d29daa55de84b782de54bf9 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Wed, 27 Dec 2023 20:09:45 -0500
Subject: [PATCH] Reorganize folder structure (#1502)

This PR reorganizes the existing `python/` folder structure for better clarity.

- `mlc_chat/model` <- `mlc_chat/compiler/model`
- `mlc_chat/quantization` <- `mlc_chat/compiler/quantization`
- `mlc_chat/loader` <- `mlc_chat/compiler/loader`
- `mlc_chat/operator` <- `mlc_chat/compiler/*_op.py`
- `mlc_chat/compiler_pass` <- `mlc_chat/compiler/compiler_pass.py`
- `mlc_chat/interface` <- `mlc_chat/compiler/{compile/gen_config/convert_weight}.py`

An import-level sketch of these moves follows the file listing below.
---
 python/mlc_chat/base.py | 36 +----
 python/mlc_chat/callback.py | 24 +++-
 python/mlc_chat/chat_module.py | 2 +-
 python/mlc_chat/cli/compile.py | 18 +--
 python/mlc_chat/cli/convert_weight.py | 14 +-
 python/mlc_chat/cli/gen_config.py | 10 +-
 python/mlc_chat/compiler/__init__.py | 14 --
 python/mlc_chat/compiler/extern/flashinfer.py | 28 ----
 .../mlc_chat/compiler/flags_optimization.py | 96 ------------
 python/mlc_chat/compiler/model/__init__.py | 3 -
 .../{compiler => }/compiler_pass/__init__.py | 0
 .../compiler_pass/attach_to_ir_module.py | 0
 .../compiler_pass/clean_up_tir_attrs.py | 0
 .../compiler_pass/estimate_memory_usage.py | 0
 .../fuse_dequantize_matmul_ewise.py | 0
 .../compiler_pass/fuse_dequantize_take.py | 0
 .../fuse_dequantize_transpose.py | 0
 .../compiler_pass/fuse_transpose_matmul.py | 0
 .../compiler_pass/lift_global_buffer_alloc.py | 2 +-
 .../{compiler => }/compiler_pass/pipeline.py | 0
 python/mlc_chat/{compiler => }/help.py | 7 +-
 .../{compiler => interface}/compile.py | 103 ++++++++++++--
 .../{compiler => interface}/convert_weight.py | 17 +--
 .../flags_model_config_override.py | 0
 .../{compiler => interface}/gen_config.py | 9 +-
 .../{compiler => }/loader/__init__.py | 0
 .../loader/huggingface_loader.py | 6 +-
 .../mlc_chat/{compiler => }/loader/loader.py | 0
 .../mlc_chat/{compiler => }/loader/mapping.py | 0
 .../mlc_chat/{compiler => }/loader/stats.py | 0
 .../mlc_chat/{compiler => }/loader/utils.py | 0
 python/mlc_chat/model/__init__.py | 3 +
 .../{compiler => }/model/gpt2/__init__.py | 0
 .../{compiler => }/model/gpt2/gpt2_loader.py | 10 +-
 .../{compiler => }/model/gpt2/gpt2_model.py | 6 +-
 .../model/gpt2/gpt2_quantization.py | 5 +-
 .../model/gpt_bigcode/__init__.py | 0
 .../model/gpt_bigcode/gpt_bigcode_loader.py | 10 +-
 .../model/gpt_bigcode/gpt_bigcode_model.py | 8 +-
 .../gpt_bigcode/gpt_bigcode_quantization.py | 5 +-
 .../{compiler => }/model/gpt_neox/__init__.py | 0
 .../model/gpt_neox/gpt_neox_loader.py | 12 +-
 .../model/gpt_neox/gpt_neox_model.py | 4 +-
 .../model/gpt_neox/gpt_neox_quantization.py | 5 +-
 .../{compiler => }/model/llama/__init__.py | 0
 .../model/llama/llama_loader.py | 15 ++-
 .../{compiler => }/model/llama/llama_model.py | 15 +--
 .../model/llama/llama_quantization.py | 5 +-
 .../{compiler => }/model/mistral/__init__.py | 0
 .../model/mistral/mistral_loader.py | 15 ++-
 .../model/mistral/mistral_model.py | 6 +-
 .../model/mistral/mistral_quantization.py | 5 +-
 python/mlc_chat/model/model.py | 126 +++++++++++++++++
 .../model/model.py => model/model_preset.py} | 127 +-----------------
 .../{compiler/extern => operator}/__init__.py | 2 +
 .../extern_op.py => operator/attention.py} | 12 +-
 .../{compiler/extern => operator}/extern.py | 20 +--
 .../position_embedding.py} | 0
 .../{compiler => }/quantization/__init__.py | 0
 .../quantization/awq_quantization.py | 4 +-
 .../quantization/group_quantization.py | 7 +-
 .../quantization/no_quantization.py | 0
 .../quantization/quantization.py | 0
 .../{compiler => }/quantization/utils.py | 0
 python/mlc_chat/rest.py | 2 +-
 python/mlc_chat/support/auto_config.py | 33 +++--
 python/mlc_chat/support/random.py | 16 +++
 .../{compiler => support}/tensor_parallel.py | 0
 tests/python/loader/test_awq.py | 5 +-
 tests/python/loader/test_huggingface.py | 4 +-
 tests/python/model/test_gpt2.py | 2 +-
 tests/python/model/test_gptNeox.py | 2 +-
 tests/python/model/test_llama.py | 2 +-
 tests/python/model/test_llama_quantization.py | 5 +-
 tests/python/model/test_mistral.py | 2 +-
 .../quantization/test_awq_quantization.py | 5 +-
 .../quantization/test_group_quantization.py | 5 +-
 77 files changed, 429 insertions(+), 470 deletions(-)
 delete mode 100644 python/mlc_chat/compiler/__init__.py
 delete mode 100644 python/mlc_chat/compiler/extern/flashinfer.py
 delete mode 100644 python/mlc_chat/compiler/flags_optimization.py
 delete mode 100644 python/mlc_chat/compiler/model/__init__.py
 rename python/mlc_chat/{compiler => }/compiler_pass/__init__.py (100%)
 rename python/mlc_chat/{compiler => }/compiler_pass/attach_to_ir_module.py (100%)
 rename python/mlc_chat/{compiler => }/compiler_pass/clean_up_tir_attrs.py (100%)
 rename python/mlc_chat/{compiler => }/compiler_pass/estimate_memory_usage.py (100%)
 rename python/mlc_chat/{compiler => }/compiler_pass/fuse_dequantize_matmul_ewise.py (100%)
 rename python/mlc_chat/{compiler => }/compiler_pass/fuse_dequantize_take.py (100%)
 rename python/mlc_chat/{compiler => }/compiler_pass/fuse_dequantize_transpose.py (100%)
 rename python/mlc_chat/{compiler => }/compiler_pass/fuse_transpose_matmul.py (100%)
 rename python/mlc_chat/{compiler => }/compiler_pass/lift_global_buffer_alloc.py (99%)
 rename python/mlc_chat/{compiler => }/compiler_pass/pipeline.py (100%)
 rename python/mlc_chat/{compiler => }/help.py (97%)
 rename python/mlc_chat/{compiler => interface}/compile.py (67%)
 rename python/mlc_chat/{compiler => interface}/convert_weight.py (93%)
 rename python/mlc_chat/{compiler => interface}/flags_model_config_override.py (100%)
 rename python/mlc_chat/{compiler => interface}/gen_config.py (97%)
 rename python/mlc_chat/{compiler => }/loader/__init__.py (100%)
 rename python/mlc_chat/{compiler => }/loader/huggingface_loader.py (99%)
 rename python/mlc_chat/{compiler => }/loader/loader.py (100%)
 rename python/mlc_chat/{compiler => }/loader/mapping.py (100%)
 rename python/mlc_chat/{compiler => }/loader/stats.py (100%)
 rename python/mlc_chat/{compiler => }/loader/utils.py (100%)
 create mode 100644 python/mlc_chat/model/__init__.py
 rename python/mlc_chat/{compiler => }/model/gpt2/__init__.py (100%)
 rename python/mlc_chat/{compiler => }/model/gpt2/gpt2_loader.py (91%)
 rename python/mlc_chat/{compiler => }/model/gpt2/gpt2_model.py (98%)
 rename python/mlc_chat/{compiler => }/model/gpt2/gpt2_quantization.py (92%)
 rename python/mlc_chat/{compiler => }/model/gpt_bigcode/__init__.py (100%)
 rename python/mlc_chat/{compiler => }/model/gpt_bigcode/gpt_bigcode_loader.py (85%)
 rename python/mlc_chat/{compiler => }/model/gpt_bigcode/gpt_bigcode_model.py (98%)
 rename python/mlc_chat/{compiler => }/model/gpt_bigcode/gpt_bigcode_quantization.py (92%)
 rename python/mlc_chat/{compiler => }/model/gpt_neox/__init__.py (100%)
 rename python/mlc_chat/{compiler => }/model/gpt_neox/gpt_neox_loader.py (90%)
 rename python/mlc_chat/{compiler => }/model/gpt_neox/gpt_neox_model.py (99%)
 rename python/mlc_chat/{compiler => }/model/gpt_neox/gpt_neox_quantization.py (90%)
 rename python/mlc_chat/{compiler => }/model/llama/__init__.py (100%)
 rename python/mlc_chat/{compiler => }/model/llama/llama_loader.py (93%)
 rename python/mlc_chat/{compiler => }/model/llama/llama_model.py (96%)
 rename python/mlc_chat/{compiler => }/model/llama/llama_quantization.py (92%)
 rename python/mlc_chat/{compiler => }/model/mistral/__init__.py (100%)
 rename python/mlc_chat/{compiler => }/model/mistral/mistral_loader.py (93%)
 rename python/mlc_chat/{compiler => }/model/mistral/mistral_model.py (99%)
 rename python/mlc_chat/{compiler => }/model/mistral/mistral_quantization.py (91%)
 create mode 100644 python/mlc_chat/model/model.py
 rename python/mlc_chat/{compiler/model/model.py => model/model_preset.py} (61%)
 rename python/mlc_chat/{compiler/extern => operator}/__init__.py (52%)
 rename python/mlc_chat/{compiler/model/extern_op.py => operator/attention.py} (93%)
 rename python/mlc_chat/{compiler/extern => operator}/extern.py (79%)
 rename python/mlc_chat/{compiler/model/position_embedding_op.py => operator/position_embedding.py} (100%)
 rename python/mlc_chat/{compiler => }/quantization/__init__.py (100%)
 rename python/mlc_chat/{compiler => }/quantization/awq_quantization.py (99%)
 rename python/mlc_chat/{compiler => }/quantization/group_quantization.py (99%)
 rename python/mlc_chat/{compiler => }/quantization/no_quantization.py (100%)
 rename python/mlc_chat/{compiler => }/quantization/quantization.py (100%)
 rename python/mlc_chat/{compiler => }/quantization/utils.py (100%)
 create mode 100644 python/mlc_chat/support/random.py
 rename python/mlc_chat/{compiler => support}/tensor_parallel.py (100%)
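As referenced in the summary above, the practical effect of this reorganization on downstream code is a change of import paths: the `mlc_chat.compiler` umbrella package disappears, and each component becomes a top-level subpackage. A minimal before/after sketch, using only imports that appear in the `cli/*.py` hunks of this patch:

```python
# Before this PR: components were re-exported through the mlc_chat.compiler umbrella.
# from mlc_chat.compiler import HELP, MODELS, QUANTIZATION, convert_weight

# After this PR: each component is imported from its own top-level subpackage.
from mlc_chat.help import HELP
from mlc_chat.interface.compile import ModelConfigOverride, OptimizationFlags
from mlc_chat.interface.convert_weight import convert_weight
from mlc_chat.model import MODELS
from mlc_chat.quantization import QUANTIZATION
```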
diff --git a/python/mlc_chat/base.py b/python/mlc_chat/base.py
index 8980330977..7dd44f1998 100644
--- a/python/mlc_chat/base.py
+++ b/python/mlc_chat/base.py
@@ -10,7 +10,7 @@


 def _load_mlc_llm_lib():
-    """Load mlc llm lib"""
+    """Load MLC LLM lib"""
     if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
         for path in libinfo.get_dll_directories():
             os.add_dll_directory(path)
@@ -24,37 +24,3 @@ def _load_mlc_llm_lib():
 # only load once here
 if os.environ.get("SKIP_LOADING_MLCLLM_SO", "0") == "0":
     _LIB, _LIB_PATH = _load_mlc_llm_lib()
-
-
-def get_delta_message(curr_message: str, new_message: str) -> str:
-    r"""Given the current message and the new message, compute the delta message
-    (the newly generated part, the diff of the new message from the current message).
-
-    Parameters
-    ----------
-    curr_message : str
-        The message generated in the previous round.
-    new_message : str
-        The message generated in the new round.
-
-    Returns
-    -------
-    delta_message : str
-        The diff of the new message from the current message (the newly generated part).
- """ - f_get_delta_message = tvm.get_global_func("mlc.get_delta_message") - return f_get_delta_message(curr_message, new_message) - - -def set_global_random_seed(seed): - """Set global random seed for python, numpy, torch and tvm.""" - if "numpy" in sys.modules: - sys.modules["numpy"].random.seed(seed) - if "torch" in sys.modules: - sys.modules["torch"].manual_seed(seed) - if "random" in sys.modules: - sys.modules["random"].seed(seed) - if "tvm" in sys.modules: - set_seed = sys.modules["tvm"].get_global_func("mlc.random.set_seed") - if set_seed: - set_seed(seed) diff --git a/python/mlc_chat/callback.py b/python/mlc_chat/callback.py index c317043fa7..bf63c31b9e 100644 --- a/python/mlc_chat/callback.py +++ b/python/mlc_chat/callback.py @@ -3,7 +3,27 @@ from queue import Queue from typing import Optional -from .base import get_delta_message + +def _get_delta_message(curr_message: str, new_message: str) -> str: + r"""Given the current message and the new message, compute the delta message + (the newly generated part, the diff of the new message from the current message). + + Parameters + ---------- + curr_message : str + The message generated in the previous round. + new_message : str + The message generated in the new round. + + Returns + ------- + delta_message : str + The diff of the new message from the current message (the newly generated part). + """ + from tvm._ffi import get_global_func # pylint: disable=import-outside-toplevel + + f_get_delta_message = get_global_func("mlc.get_delta_message") + return f_get_delta_message(curr_message, new_message) class DeltaCallback: @@ -27,7 +47,7 @@ def __call__(self, message: str = "", stopped: bool = False): self.stopped_callback() self.curr_message = "" else: - delta = get_delta_message(self.curr_message, message) + delta = _get_delta_message(self.curr_message, message) self.curr_message = message self.delta_callback(delta) diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 86228380c8..768398fde4 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -15,7 +15,7 @@ from mlc_chat.support import logging from mlc_chat.support.auto_device import detect_device -from . import base # pylint: disable=unused-import +from . 
import base as _ if TYPE_CHECKING: from .interface.openai_api import ChatMessage diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index 72d8f8d582..9a22927bb4 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -5,22 +5,24 @@ from pathlib import Path from typing import Union -from mlc_chat.compiler import ( # pylint: disable=redefined-builtin - HELP, - MODELS, - QUANTIZATION, +from mlc_chat.help import HELP +from mlc_chat.interface.compile import ( # pylint: disable=redefined-builtin ModelConfigOverride, OptimizationFlags, compile, ) - -from ..support.argparse import ArgumentParser -from ..support.auto_config import ( +from mlc_chat.model import MODELS +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.support.argparse import ArgumentParser +from mlc_chat.support.auto_config import ( detect_mlc_chat_config, detect_model_type, detect_quantization, ) -from ..support.auto_target import detect_system_lib_prefix, detect_target_and_host +from mlc_chat.support.auto_target import ( + detect_system_lib_prefix, + detect_target_and_host, +) def main(argv): diff --git a/python/mlc_chat/cli/convert_weight.py b/python/mlc_chat/cli/convert_weight.py index 99c9d3832b..5e97cc7486 100644 --- a/python/mlc_chat/cli/convert_weight.py +++ b/python/mlc_chat/cli/convert_weight.py @@ -3,12 +3,14 @@ from pathlib import Path from typing import Union -from mlc_chat.compiler import HELP, MODELS, QUANTIZATION, convert_weight - -from ..support.argparse import ArgumentParser -from ..support.auto_config import detect_config, detect_model_type -from ..support.auto_device import detect_device -from ..support.auto_weight import detect_weight +from mlc_chat.help import HELP +from mlc_chat.interface.convert_weight import convert_weight +from mlc_chat.model import MODELS +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.support.argparse import ArgumentParser +from mlc_chat.support.auto_config import detect_config, detect_model_type +from mlc_chat.support.auto_device import detect_device +from mlc_chat.support.auto_weight import detect_weight def main(argv): diff --git a/python/mlc_chat/cli/gen_config.py b/python/mlc_chat/cli/gen_config.py index 89a1dfe558..4ff09b5a8a 100644 --- a/python/mlc_chat/cli/gen_config.py +++ b/python/mlc_chat/cli/gen_config.py @@ -2,10 +2,12 @@ from pathlib import Path from typing import Union -from mlc_chat.compiler import CONV_TEMPLATES, HELP, MODELS, QUANTIZATION, gen_config - -from ..support.argparse import ArgumentParser -from ..support.auto_config import detect_config, detect_model_type +from mlc_chat.help import HELP +from mlc_chat.interface.gen_config import CONV_TEMPLATES, gen_config +from mlc_chat.model import MODELS +from mlc_chat.quantization import QUANTIZATION +from mlc_chat.support.argparse import ArgumentParser +from mlc_chat.support.auto_config import detect_config, detect_model_type def main(argv): diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py deleted file mode 100644 index eb807e04a2..0000000000 --- a/python/mlc_chat/compiler/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -A compiler for MLC Chat. By default, it is not imported to MLC Chat to avoid unnecessary dependency, -but users could optionally import it if they want to use the compiler. -""" -from . 
import compiler_pass -from .compile import CompileArgs, compile # pylint: disable=redefined-builtin -from .convert_weight import ConversionArgs, convert_weight -from .flags_model_config_override import ModelConfigOverride -from .flags_optimization import OptimizationFlags -from .gen_config import CONV_TEMPLATES, gen_config -from .help import HELP -from .loader import LOADER, ExternMapping, HuggingFaceLoader, QuantizeMapping -from .model import MODEL_PRESETS, MODELS, Model -from .quantization import QUANTIZATION, Quantization diff --git a/python/mlc_chat/compiler/extern/flashinfer.py b/python/mlc_chat/compiler/extern/flashinfer.py deleted file mode 100644 index cd4f7805cd..0000000000 --- a/python/mlc_chat/compiler/extern/flashinfer.py +++ /dev/null @@ -1,28 +0,0 @@ -"""FlashInfer library.""" -import dataclasses - - -@dataclasses.dataclass -class FlashInfer: - """A fast kernel library for LLM inference.""" - - rope_scale: float = 1.0 - rope_theta: float = 10000.0 - - def configure( - self, - rope_scale: float, - rope_theta: float, - ): - """Configure FlashInfer as an external operator - - Parameters - ---------- - rope_scale : float - Scaling factor for the RoPE embedding. - - rope_theta : float - The base period of the RoPE embedding. - """ - self.rope_scale = rope_scale - self.rope_theta = rope_theta diff --git a/python/mlc_chat/compiler/flags_optimization.py b/python/mlc_chat/compiler/flags_optimization.py deleted file mode 100644 index 3819c62b63..0000000000 --- a/python/mlc_chat/compiler/flags_optimization.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Optimization flags""" -import argparse -import dataclasses -from io import StringIO - -from tvm.target import Target - -from mlc_chat.support.logging import getLogger - -logger = getLogger(__name__) - - -@dataclasses.dataclass -class OptimizationFlags: - """Optimization flags""" - - flashinfer: bool = False - cublas_gemm: bool = False - cudagraph: bool = False - - def __repr__(self) -> str: - out = StringIO() - print(f"flashinfer={int(self.flashinfer)}", file=out, end="") - print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="") - print(f";cudagraph={int(self.cudagraph)}", file=out, end="") - return out.getvalue().rstrip() - - @staticmethod - def from_str(source: str) -> "OptimizationFlags": - """Parse optimization flags from a string.""" - - if source in OPT_FLAG_PRESET: - return OPT_FLAG_PRESET[source] - - def boolean(value: str) -> bool: - if value == "0": - return False - if value == "1": - return True - raise ValueError(f"Invalid boolean value: {value}") - - parser = argparse.ArgumentParser(description="optimization flags") - parser.add_argument("--flashinfer", type=boolean, default=True) - parser.add_argument("--cublas_gemm", type=boolean, default=False) - parser.add_argument("--cudagraph", type=boolean, default=False) - results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) - return OptimizationFlags( - flashinfer=results.flashinfer, - cublas_gemm=results.cublas_gemm, - cudagraph=results.cudagraph, - ) - - def update(self, target: Target) -> None: - """Update optimization flags based on additional information.""" - - def _flashinfer(target) -> bool: - from mlc_chat.support.auto_target import ( # pylint: disable=import-outside-toplevel - detect_cuda_arch_list, - ) - - if not self.flashinfer: - return False - if target.kind.name != "cuda": - return False - arch_list = detect_cuda_arch_list(target) - for arch in arch_list: - if arch < 80: - logger.warning("flashinfer is not supported on CUDA arch < 80") - return False 
- return True - - self.flashinfer = _flashinfer(target) - - -OPT_FLAG_PRESET = { - "O0": OptimizationFlags( - flashinfer=False, - cublas_gemm=False, - cudagraph=False, - ), - "O1": OptimizationFlags( - flashinfer=False, - cublas_gemm=True, - cudagraph=False, - ), - "O2": OptimizationFlags( - flashinfer=False, - cublas_gemm=True, - cudagraph=False, - ), - "O3": OptimizationFlags( - flashinfer=True, - cublas_gemm=True, - cudagraph=True, - ), -} diff --git a/python/mlc_chat/compiler/model/__init__.py b/python/mlc_chat/compiler/model/__init__.py deleted file mode 100644 index 87dcd49097..0000000000 --- a/python/mlc_chat/compiler/model/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Model definition for the compiler.""" -from . import llama, mistral -from .model import MODEL_PRESETS, MODELS, Model diff --git a/python/mlc_chat/compiler/compiler_pass/__init__.py b/python/mlc_chat/compiler_pass/__init__.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/__init__.py rename to python/mlc_chat/compiler_pass/__init__.py diff --git a/python/mlc_chat/compiler/compiler_pass/attach_to_ir_module.py b/python/mlc_chat/compiler_pass/attach_to_ir_module.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/attach_to_ir_module.py rename to python/mlc_chat/compiler_pass/attach_to_ir_module.py diff --git a/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py b/python/mlc_chat/compiler_pass/clean_up_tir_attrs.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py rename to python/mlc_chat/compiler_pass/clean_up_tir_attrs.py diff --git a/python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py b/python/mlc_chat/compiler_pass/estimate_memory_usage.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py rename to python/mlc_chat/compiler_pass/estimate_memory_usage.py diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_matmul_ewise.py b/python/mlc_chat/compiler_pass/fuse_dequantize_matmul_ewise.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/fuse_dequantize_matmul_ewise.py rename to python/mlc_chat/compiler_pass/fuse_dequantize_matmul_ewise.py diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_take.py b/python/mlc_chat/compiler_pass/fuse_dequantize_take.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/fuse_dequantize_take.py rename to python/mlc_chat/compiler_pass/fuse_dequantize_take.py diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_transpose.py b/python/mlc_chat/compiler_pass/fuse_dequantize_transpose.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/fuse_dequantize_transpose.py rename to python/mlc_chat/compiler_pass/fuse_dequantize_transpose.py diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py b/python/mlc_chat/compiler_pass/fuse_transpose_matmul.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py rename to python/mlc_chat/compiler_pass/fuse_transpose_matmul.py diff --git a/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py b/python/mlc_chat/compiler_pass/lift_global_buffer_alloc.py similarity index 99% rename from python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py rename to python/mlc_chat/compiler_pass/lift_global_buffer_alloc.py index ebf8f27acf..bf709bce04 100644 --- 
a/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py +++ b/python/mlc_chat/compiler_pass/lift_global_buffer_alloc.py @@ -169,7 +169,7 @@ def _resolve_tir_var_mapping( # pylint: disable=too-many-locals ret_tensors = call.sinfo_args[0] ret_tensors = ( - [ret_tensors] + [ret_tensors] # type: ignore[assignment] if isinstance(ret_tensors, relax.TensorStructInfo) else list(ret_tensors.fields) ) diff --git a/python/mlc_chat/compiler/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py similarity index 100% rename from python/mlc_chat/compiler/compiler_pass/pipeline.py rename to python/mlc_chat/compiler_pass/pipeline.py diff --git a/python/mlc_chat/compiler/help.py b/python/mlc_chat/help.py similarity index 97% rename from python/mlc_chat/compiler/help.py rename to python/mlc_chat/help.py index 009ef8d7d3..da3c002c83 100644 --- a/python/mlc_chat/compiler/help.py +++ b/python/mlc_chat/help.py @@ -1,6 +1,4 @@ """Help message for CLI arguments.""" -from .model import MODEL_PRESETS - HELP = { "config": ( """ @@ -17,10 +15,7 @@ as well as an optional `generation_config.json` provides additional default configuration for text generation. Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. - -Pre-defined model architectures include """ - + ", ".join(f'"{preset}"' for preset in MODEL_PRESETS) - + "." +""" ).strip(), "quantization": """ The quantization mode we use to compile. If unprovided, will infer from `model`. diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/interface/compile.py similarity index 67% rename from python/mlc_chat/compiler/compile.py rename to python/mlc_chat/interface/compile.py index 09fd0d6174..81b147956a 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/interface/compile.py @@ -9,18 +9,81 @@ from tvm.relax.frontend import nn from tvm.target import Target -from ..support import logging -from ..support.config import ConfigBase -from ..support.style import bold -from . 
import extern +from mlc_chat import compiler_pass as _ +from mlc_chat import operator as op_ext +from mlc_chat.model import Model +from mlc_chat.quantization import Quantization +from mlc_chat.support import argparse, logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + from .flags_model_config_override import ModelConfigOverride -from .flags_optimization import OptimizationFlags -from .model import Model -from .quantization import Quantization logger = logging.getLogger(__name__) +@dataclasses.dataclass +class OptimizationFlags: + """Optimization flags""" + + flashinfer: bool = False + cublas_gemm: bool = False + cudagraph: bool = False + + def __repr__(self) -> str: + out = StringIO() + print(f"flashinfer={int(self.flashinfer)}", file=out, end="") + print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="") + print(f";cudagraph={int(self.cudagraph)}", file=out, end="") + return out.getvalue().rstrip() + + @staticmethod + def from_str(source: str) -> "OptimizationFlags": + """Parse optimization flags from a string.""" + + if source in OPT_FLAG_PRESET: + return OPT_FLAG_PRESET[source] + + def boolean(value: str) -> bool: + if value == "0": + return False + if value == "1": + return True + raise ValueError(f"Invalid boolean value: {value}") + + parser = argparse.ArgumentParser(description="optimization flags") + parser.add_argument("--flashinfer", type=boolean, default=True) + parser.add_argument("--cublas_gemm", type=boolean, default=False) + parser.add_argument("--cudagraph", type=boolean, default=False) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return OptimizationFlags( + flashinfer=results.flashinfer, + cublas_gemm=results.cublas_gemm, + cudagraph=results.cudagraph, + ) + + def update(self, target: Target) -> None: + """Update optimization flags based on additional information.""" + + def _flashinfer(target) -> bool: + from mlc_chat.support.auto_target import ( # pylint: disable=import-outside-toplevel + detect_cuda_arch_list, + ) + + if not self.flashinfer: + return False + if target.kind.name != "cuda": + return False + arch_list = detect_cuda_arch_list(target) + for arch in arch_list: + if arch < 80: + logger.warning("flashinfer is not supported on CUDA arch < 80") + return False + return True + + self.flashinfer = _flashinfer(target) + + @dataclasses.dataclass class CompileArgs: # pylint: disable=too-many-instance-attributes """Arguments to MLC LLM's compiler.""" @@ -100,7 +163,7 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]: args.overrides.apply(model_config) with args.target: - extern.enable( + op_ext.enable( target=args.target, flashinfer=args.opt.flashinfer, ) @@ -171,3 +234,27 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin ) args.display() _compile(args, model_config) + + +OPT_FLAG_PRESET = { + "O0": OptimizationFlags( + flashinfer=False, + cublas_gemm=False, + cudagraph=False, + ), + "O1": OptimizationFlags( + flashinfer=False, + cublas_gemm=True, + cudagraph=False, + ), + "O2": OptimizationFlags( + flashinfer=False, + cublas_gemm=True, + cudagraph=False, + ), + "O3": OptimizationFlags( + flashinfer=True, + cublas_gemm=True, + cudagraph=True, + ), +} diff --git a/python/mlc_chat/compiler/convert_weight.py b/python/mlc_chat/interface/convert_weight.py similarity index 93% rename from python/mlc_chat/compiler/convert_weight.py rename to python/mlc_chat/interface/convert_weight.py index a2bd5c4523..5f0441450b 100644 --- 
a/python/mlc_chat/compiler/convert_weight.py +++ b/python/mlc_chat/interface/convert_weight.py @@ -11,13 +11,11 @@ from tvm.runtime import cpu as cpu_device from tvm.target import Target -from mlc_chat.support import tqdm - -from ..support import logging -from ..support.style import bold, green -from .loader import LOADER -from .model import Model -from .quantization import Quantization +from mlc_chat.loader import LOADER +from mlc_chat.model import Model +from mlc_chat.quantization import Quantization +from mlc_chat.support import logging, tqdm +from mlc_chat.support.style import bold, green logger = logging.getLogger(__name__) @@ -66,7 +64,10 @@ def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-loc model, quantize_map = args.model.quantize[args.quantization.kind]( model_config, args.quantization ) - _, _named_params = model.export_tvm(spec=model.get_default_spec()) # type: ignore[attr-defined] + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) named_params = dict(_named_params) def _check_param(name: str, param: NDArray): diff --git a/python/mlc_chat/compiler/flags_model_config_override.py b/python/mlc_chat/interface/flags_model_config_override.py similarity index 100% rename from python/mlc_chat/compiler/flags_model_config_override.py rename to python/mlc_chat/interface/flags_model_config_override.py diff --git a/python/mlc_chat/compiler/gen_config.py b/python/mlc_chat/interface/gen_config.py similarity index 97% rename from python/mlc_chat/compiler/gen_config.py rename to python/mlc_chat/interface/gen_config.py index 7a1be20e82..4df92b8f19 100644 --- a/python/mlc_chat/compiler/gen_config.py +++ b/python/mlc_chat/interface/gen_config.py @@ -5,11 +5,12 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from ..support import logging -from ..support.style import bold, green, red +from mlc_chat.model import Model +from mlc_chat.quantization import Quantization +from mlc_chat.support import logging +from mlc_chat.support.style import bold, green, red + from .flags_model_config_override import ModelConfigOverride -from .model import Model -from .quantization import Quantization logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/compiler/loader/__init__.py b/python/mlc_chat/loader/__init__.py similarity index 100% rename from python/mlc_chat/compiler/loader/__init__.py rename to python/mlc_chat/loader/__init__.py diff --git a/python/mlc_chat/compiler/loader/huggingface_loader.py b/python/mlc_chat/loader/huggingface_loader.py similarity index 99% rename from python/mlc_chat/compiler/loader/huggingface_loader.py rename to python/mlc_chat/loader/huggingface_loader.py index c6d03b8f46..c9dc75b5ce 100644 --- a/python/mlc_chat/compiler/loader/huggingface_loader.py +++ b/python/mlc_chat/loader/huggingface_loader.py @@ -1,5 +1,4 @@ """A weight loader for HuggingFace's PyTorch format""" - import gc import json from collections import OrderedDict, defaultdict @@ -11,8 +10,9 @@ from tvm.runtime import Device, NDArray from tvm.runtime.ndarray import array as as_ndarray -from ...support import logging -from ...support.style import bold +from mlc_chat.support import logging +from mlc_chat.support.style import bold + from .mapping import ExternMapping, QuantizeMapping from .stats import Stats from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard diff --git a/python/mlc_chat/compiler/loader/loader.py 
b/python/mlc_chat/loader/loader.py similarity index 100% rename from python/mlc_chat/compiler/loader/loader.py rename to python/mlc_chat/loader/loader.py diff --git a/python/mlc_chat/compiler/loader/mapping.py b/python/mlc_chat/loader/mapping.py similarity index 100% rename from python/mlc_chat/compiler/loader/mapping.py rename to python/mlc_chat/loader/mapping.py diff --git a/python/mlc_chat/compiler/loader/stats.py b/python/mlc_chat/loader/stats.py similarity index 100% rename from python/mlc_chat/compiler/loader/stats.py rename to python/mlc_chat/loader/stats.py diff --git a/python/mlc_chat/compiler/loader/utils.py b/python/mlc_chat/loader/utils.py similarity index 100% rename from python/mlc_chat/compiler/loader/utils.py rename to python/mlc_chat/loader/utils.py diff --git a/python/mlc_chat/model/__init__.py b/python/mlc_chat/model/__init__.py new file mode 100644 index 0000000000..d7b0baaa71 --- /dev/null +++ b/python/mlc_chat/model/__init__.py @@ -0,0 +1,3 @@ +"""Model definition for the compiler.""" +from .model import MODELS, Model +from .model_preset import MODEL_PRESETS diff --git a/python/mlc_chat/compiler/model/gpt2/__init__.py b/python/mlc_chat/model/gpt2/__init__.py similarity index 100% rename from python/mlc_chat/compiler/model/gpt2/__init__.py rename to python/mlc_chat/model/gpt2/__init__.py diff --git a/python/mlc_chat/compiler/model/gpt2/gpt2_loader.py b/python/mlc_chat/model/gpt2/gpt2_loader.py similarity index 91% rename from python/mlc_chat/compiler/model/gpt2/gpt2_loader.py rename to python/mlc_chat/model/gpt2/gpt2_loader.py index 9f574c3de7..43c4ff14e1 100644 --- a/python/mlc_chat/compiler/model/gpt2/gpt2_loader.py +++ b/python/mlc_chat/model/gpt2/gpt2_loader.py @@ -4,8 +4,9 @@ """ import functools -from ...loader import ExternMapping -from ...quantization import Quantization +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + from .gpt2_model import GPT2Config, GPT2LMHeadModel @@ -29,8 +30,9 @@ def huggingface(model_config: GPT2Config, quantization: Quantization) -> ExternM model = GPT2LMHeadModel(model_config) if quantization is not None: model.to(quantization.model_dtype) - _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking - spec=model.get_default_spec() + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, ) named_parameters = dict(_named_params) diff --git a/python/mlc_chat/compiler/model/gpt2/gpt2_model.py b/python/mlc_chat/model/gpt2/gpt2_model.py similarity index 98% rename from python/mlc_chat/compiler/model/gpt2/gpt2_model.py rename to python/mlc_chat/model/gpt2/gpt2_model.py index 6128b2a2f6..b6a122dfd0 100644 --- a/python/mlc_chat/compiler/model/gpt2/gpt2_model.py +++ b/python/mlc_chat/model/gpt2/gpt2_model.py @@ -10,9 +10,9 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from ....support import logging -from ....support.config import ConfigBase -from ....support.style import bold +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/compiler/model/gpt2/gpt2_quantization.py b/python/mlc_chat/model/gpt2/gpt2_quantization.py similarity index 92% rename from python/mlc_chat/compiler/model/gpt2/gpt2_quantization.py rename to python/mlc_chat/model/gpt2/gpt2_quantization.py index 3c5f163396..f042239966 100644 --- 
a/python/mlc_chat/compiler/model/gpt2/gpt2_quantization.py +++ b/python/mlc_chat/model/gpt2/gpt2_quantization.py @@ -4,8 +4,9 @@ from tvm.relax.frontend import nn -from ...loader import QuantizeMapping -from ...quantization import AWQQuantize, GroupQuantize, NoQuantize +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, GroupQuantize, NoQuantize + from .gpt2_model import GPT2Config, GPT2LMHeadModel diff --git a/python/mlc_chat/compiler/model/gpt_bigcode/__init__.py b/python/mlc_chat/model/gpt_bigcode/__init__.py similarity index 100% rename from python/mlc_chat/compiler/model/gpt_bigcode/__init__.py rename to python/mlc_chat/model/gpt_bigcode/__init__.py diff --git a/python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_loader.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py similarity index 85% rename from python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_loader.py rename to python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py index dd0656c587..8d479d3ad8 100644 --- a/python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_loader.py +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py @@ -4,8 +4,9 @@ """ import functools -from ...loader import ExternMapping -from ...quantization import Quantization +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + from .gpt_bigcode_model import GPTBigCodeConfig, GPTBigCodeForCausalLM @@ -29,8 +30,9 @@ def huggingface(model_config: GPTBigCodeConfig, quantization: Quantization) -> E model = GPTBigCodeForCausalLM(model_config) if quantization is not None: model.to(quantization.model_dtype) - _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking - spec=model.get_default_spec() + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, ) named_parameters = dict(_named_params) diff --git a/python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py similarity index 98% rename from python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_model.py rename to python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py index 5ff75587fb..360b638065 100644 --- a/python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py @@ -10,10 +10,10 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from ....support import logging -from ....support.config import ConfigBase -from ....support.style import bold -from ... 
import tensor_parallel as tp +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_quantization.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py similarity index 92% rename from python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_quantization.py rename to python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py index 5c9cc9c174..5063909a90 100644 --- a/python/mlc_chat/compiler/model/gpt_bigcode/gpt_bigcode_quantization.py +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py @@ -4,8 +4,9 @@ from tvm.relax.frontend import nn -from ...loader import QuantizeMapping -from ...quantization import AWQQuantize, GroupQuantize, NoQuantize +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, GroupQuantize, NoQuantize + from .gpt_bigcode_model import GPTBigCodeConfig, GPTBigCodeForCausalLM diff --git a/python/mlc_chat/compiler/model/gpt_neox/__init__.py b/python/mlc_chat/model/gpt_neox/__init__.py similarity index 100% rename from python/mlc_chat/compiler/model/gpt_neox/__init__.py rename to python/mlc_chat/model/gpt_neox/__init__.py diff --git a/python/mlc_chat/compiler/model/gpt_neox/gpt_neox_loader.py b/python/mlc_chat/model/gpt_neox/gpt_neox_loader.py similarity index 90% rename from python/mlc_chat/compiler/model/gpt_neox/gpt_neox_loader.py rename to python/mlc_chat/model/gpt_neox/gpt_neox_loader.py index d421ccf660..b7e4027ce2 100644 --- a/python/mlc_chat/compiler/model/gpt_neox/gpt_neox_loader.py +++ b/python/mlc_chat/model/gpt_neox/gpt_neox_loader.py @@ -6,8 +6,9 @@ import numpy as np -from ...loader import ExternMapping -from ...quantization import Quantization +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + from .gpt_neox_model import GPTNeoXConfig, GPTNeoXForCausalLM @@ -31,8 +32,9 @@ def huggingface(model_config: GPTNeoXConfig, quantization: Quantization) -> Exte model = GPTNeoXForCausalLM(model_config) if quantization is not None: model.to(quantization.model_dtype) - _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking - spec=model.get_default_spec() + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, ) named_parameters = dict(_named_params) @@ -46,7 +48,7 @@ def huggingface(model_config: GPTNeoXConfig, quantization: Quantization) -> Exte mapping.add_unused(f"{attn}.bias") # change the layout of query_key_value - def transform_qkv_layout(w, dtype): + def transform_qkv_layout(w, dtype): # pylint: disable=invalid-name num_attention_heads = model_config.num_attention_heads head_dim = model_config.head_dim diff --git a/python/mlc_chat/compiler/model/gpt_neox/gpt_neox_model.py b/python/mlc_chat/model/gpt_neox/gpt_neox_model.py similarity index 99% rename from python/mlc_chat/compiler/model/gpt_neox/gpt_neox_model.py rename to python/mlc_chat/model/gpt_neox/gpt_neox_model.py index 11ff7d005d..71693816c6 100644 --- a/python/mlc_chat/compiler/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_chat/model/gpt_neox/gpt_neox_model.py @@ -11,8 +11,8 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from ....support.config import ConfigBase -from ....support.style import bold +from mlc_chat.support.config import ConfigBase 
+from mlc_chat.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/compiler/model/gpt_neox/gpt_neox_quantization.py b/python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py similarity index 90% rename from python/mlc_chat/compiler/model/gpt_neox/gpt_neox_quantization.py rename to python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py index 11fc3ea61c..53bc5394b4 100644 --- a/python/mlc_chat/compiler/model/gpt_neox/gpt_neox_quantization.py +++ b/python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py @@ -4,8 +4,9 @@ from tvm.relax.frontend import nn -from ...loader import QuantizeMapping -from ...quantization import GroupQuantize, NoQuantize +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import GroupQuantize, NoQuantize + from .gpt_neox_model import GPTNeoXConfig, GPTNeoXForCausalLM diff --git a/python/mlc_chat/compiler/model/llama/__init__.py b/python/mlc_chat/model/llama/__init__.py similarity index 100% rename from python/mlc_chat/compiler/model/llama/__init__.py rename to python/mlc_chat/model/llama/__init__.py diff --git a/python/mlc_chat/compiler/model/llama/llama_loader.py b/python/mlc_chat/model/llama/llama_loader.py similarity index 93% rename from python/mlc_chat/compiler/model/llama/llama_loader.py rename to python/mlc_chat/model/llama/llama_loader.py index 46d0477bbf..5dd902d04d 100644 --- a/python/mlc_chat/compiler/model/llama/llama_loader.py +++ b/python/mlc_chat/model/llama/llama_loader.py @@ -6,8 +6,9 @@ import numpy as np -from ...loader import ExternMapping -from ...quantization import Quantization +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + from .llama_model import LlamaConfig, LlamaForCasualLM from .llama_quantization import awq_quant @@ -32,8 +33,9 @@ def huggingface(model_config: LlamaConfig, quantization: Quantization) -> Extern model = LlamaForCasualLM(model_config) if quantization is not None: model.to(quantization.model_dtype) - _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking - spec=model.get_default_spec() + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, ) named_parameters = dict(_named_params) @@ -104,7 +106,10 @@ def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: The parameter mapping from MLC to AWQ. """ model, _ = awq_quant(model_config, quantization) - _, _named_params = model.export_tvm(spec=model.get_default_spec()) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) named_parameters = dict(_named_params) mapping = ExternMapping() diff --git a/python/mlc_chat/compiler/model/llama/llama_model.py b/python/mlc_chat/model/llama/llama_model.py similarity index 96% rename from python/mlc_chat/compiler/model/llama/llama_model.py rename to python/mlc_chat/model/llama/llama_model.py index 9ec9475701..73826f3e7c 100644 --- a/python/mlc_chat/compiler/model/llama/llama_model.py +++ b/python/mlc_chat/model/llama/llama_model.py @@ -9,14 +9,12 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op +from mlc_chat import operator as op_ext from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp from mlc_chat.support.config import ConfigBase from mlc_chat.support.style import bold -from ... import tensor_parallel as tp -from .. 
import extern_op -from ..position_embedding_op import llama_rope - logger = logging.getLogger(__name__) @@ -120,14 +118,14 @@ def forward( # pylint: disable=too-many-locals qkv = self.qkv_proj(hidden_states) qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) # Step 2. Apply QK rotary embedding - q, k, v = llama_rope(qkv, t, self.rope_theta, h_q, h_kv) + q, k, v = op_ext.llama_rope(qkv, t, self.rope_theta, h_q, h_kv) # Step 3. Query and update KVCache self.k_cache.append(op.squeeze(k, axis=0)) self.v_cache.append(op.squeeze(v, axis=0)) k = self.k_cache.view(t) v = self.v_cache.view(t) # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V - output = extern_op.attention(q, k, v, casual_mask=attention_mask) + output = op_ext.attention(q, k, v, casual_mask=attention_mask) # Step 5. Apply output projection return self.o_proj(output) @@ -205,10 +203,7 @@ def to(self, dtype: Optional[str] = None): self.dtype = dtype def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): - extern_op.configure( - rope_theta=self.rope_theta, - rope_scale=1.0, - ) + op_ext.configure() def _index(x: te.Tensor): # x[:-1,:] b, s, d = x.shape diff --git a/python/mlc_chat/compiler/model/llama/llama_quantization.py b/python/mlc_chat/model/llama/llama_quantization.py similarity index 92% rename from python/mlc_chat/compiler/model/llama/llama_quantization.py rename to python/mlc_chat/model/llama/llama_quantization.py index 598c7be3fb..680d608efe 100644 --- a/python/mlc_chat/compiler/model/llama/llama_quantization.py +++ b/python/mlc_chat/model/llama/llama_quantization.py @@ -4,8 +4,9 @@ from tvm.relax.frontend import nn -from ...loader import QuantizeMapping -from ...quantization import AWQQuantize, GroupQuantize, NoQuantize +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, GroupQuantize, NoQuantize + from .llama_model import LlamaConfig, LlamaForCasualLM diff --git a/python/mlc_chat/compiler/model/mistral/__init__.py b/python/mlc_chat/model/mistral/__init__.py similarity index 100% rename from python/mlc_chat/compiler/model/mistral/__init__.py rename to python/mlc_chat/model/mistral/__init__.py diff --git a/python/mlc_chat/compiler/model/mistral/mistral_loader.py b/python/mlc_chat/model/mistral/mistral_loader.py similarity index 93% rename from python/mlc_chat/compiler/model/mistral/mistral_loader.py rename to python/mlc_chat/model/mistral/mistral_loader.py index 5f9f96b0d4..71a8f1abe9 100644 --- a/python/mlc_chat/compiler/model/mistral/mistral_loader.py +++ b/python/mlc_chat/model/mistral/mistral_loader.py @@ -6,8 +6,9 @@ import numpy as np -from ...loader import ExternMapping -from ...quantization import Quantization +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + from .mistral_model import MistralConfig, MistralForCasualLM from .mistral_quantization import awq_quant @@ -32,8 +33,9 @@ def huggingface(model_config: MistralConfig, quantization: Quantization) -> Exte model = MistralForCasualLM(model_config) if quantization is not None: model.to(quantization.model_dtype) - _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking - spec=model.get_default_spec() + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, ) named_parameters = dict(_named_params) @@ -104,7 +106,10 @@ def awq(model_config: MistralConfig, quantization: Quantization) -> ExternMappin The parameter mapping from MLC to AWQ. 
""" model, _ = awq_quant(model_config, quantization) - _, _named_params = model.export_tvm(spec=model.get_default_spec()) + _, _named_params = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) named_parameters = dict(_named_params) mapping = ExternMapping() diff --git a/python/mlc_chat/compiler/model/mistral/mistral_model.py b/python/mlc_chat/model/mistral/mistral_model.py similarity index 99% rename from python/mlc_chat/compiler/model/mistral/mistral_model.py rename to python/mlc_chat/model/mistral/mistral_model.py index 2f5fa2824c..8c0b45a9cf 100644 --- a/python/mlc_chat/compiler/model/mistral/mistral_model.py +++ b/python/mlc_chat/model/mistral/mistral_model.py @@ -10,9 +10,9 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from ....support import logging -from ....support.config import ConfigBase -from ....support.style import bold +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/compiler/model/mistral/mistral_quantization.py b/python/mlc_chat/model/mistral/mistral_quantization.py similarity index 91% rename from python/mlc_chat/compiler/model/mistral/mistral_quantization.py rename to python/mlc_chat/model/mistral/mistral_quantization.py index eecff1a63a..0ba59a97de 100644 --- a/python/mlc_chat/compiler/model/mistral/mistral_quantization.py +++ b/python/mlc_chat/model/mistral/mistral_quantization.py @@ -4,8 +4,9 @@ from tvm.relax.frontend import nn -from ...loader import QuantizeMapping -from ...quantization import AWQQuantize, GroupQuantize +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, GroupQuantize + from .mistral_model import MistralConfig, MistralForCasualLM diff --git a/python/mlc_chat/model/model.py b/python/mlc_chat/model/model.py new file mode 100644 index 0000000000..47bc46ed06 --- /dev/null +++ b/python/mlc_chat/model/model.py @@ -0,0 +1,126 @@ +"""A centralized registry of all existing model architures and their configurations.""" +import dataclasses +from typing import Any, Callable, Dict, Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import ExternMapping, QuantizeMapping +from mlc_chat.quantization.quantization import Quantization + +from .gpt2 import gpt2_loader, gpt2_model, gpt2_quantization +from .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model, gpt_bigcode_quantization +from .gpt_neox import gpt_neox_loader, gpt_neox_model, gpt_neox_quantization +from .llama import llama_loader, llama_model, llama_quantization +from .mistral import mistral_loader, mistral_model, mistral_quantization + +ModelConfig = Any +"""A ModelConfig is an object that represents a model architecture. It is required to have +a class method `from_file` with the following signature: + + def from_file(cls, path: Path) -> ModelConfig: + ... +""" + +FuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping] +FuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]] + + +@dataclasses.dataclass +class Model: + """All about a model architecture: its configuration, its parameter loader and quantization. + + Parameters + ---------- + name : str + The name of the model. + + model : Callable[[ModelConfig], nn.Module] + A method that creates the `nn.Module` that represents the model from `ModelConfig`. 
+ + config : ModelConfig + A class that has a `from_file` class method, whose signature is "Path -> ModelConfig". + + source : Dict[str, FuncGetExternMap] + A dictionary that maps the name of a source format to parameter mapping. + + quantize: Dict[str, FuncQuantization] + A dictionary that maps the name of a quantization method to quantized model and the + quantization parameter mapping. + """ + + name: str + config: ModelConfig + model: Callable[[ModelConfig], nn.Module] + source: Dict[str, FuncGetExternMap] + quantize: Dict[str, FuncQuantization] + + +MODELS: Dict[str, Model] = { + "llama": Model( + name="llama", + model=llama_model.LlamaForCasualLM, + config=llama_model.LlamaConfig, + source={ + "huggingface-torch": llama_loader.huggingface, + "huggingface-safetensor": llama_loader.huggingface, + "awq": llama_loader.awq, + }, + quantize={ + "no-quant": llama_quantization.no_quant, + "group-quant": llama_quantization.group_quant, + "awq": llama_quantization.awq_quant, + }, + ), + "mistral": Model( + name="mistral", + model=mistral_model.MistralForCasualLM, + config=mistral_model.MistralConfig, + source={ + "huggingface-torch": mistral_loader.huggingface, + "huggingface-safetensor": mistral_loader.huggingface, + "awq": mistral_loader.awq, + }, + quantize={ + "group-quant": mistral_quantization.group_quant, + }, + ), + "gpt2": Model( + name="gpt2", + model=gpt2_model.GPT2LMHeadModel, + config=gpt2_model.GPT2Config, + source={ + "huggingface-torch": gpt2_loader.huggingface, + "huggingface-safetensor": gpt2_loader.huggingface, + }, + quantize={ + "no-quant": gpt2_quantization.no_quant, + "group-quant": gpt2_quantization.group_quant, + }, + ), + "gpt_neox": Model( + name="gpt_neox", + model=gpt_neox_model.GPTNeoXForCausalLM, + config=gpt_neox_model.GPTNeoXConfig, + source={ + "huggingface-torch": gpt_neox_loader.huggingface, + "huggingface-safetensor": gpt_neox_loader.huggingface, + }, + quantize={ + "no-quant": gpt_neox_quantization.no_quant, + "group-quant": gpt_neox_quantization.group_quant, + }, + ), + "gpt_bigcode": Model( + name="gpt_bigcode", + model=gpt_bigcode_model.GPTBigCodeForCausalLM, + config=gpt_bigcode_model.GPTBigCodeConfig, + source={ + "huggingface-torch": gpt_bigcode_loader.huggingface, + "huggingface-safetensor": gpt_bigcode_loader.huggingface, + }, + quantize={ + "no-quant": gpt_bigcode_quantization.no_quant, + "group-quant": gpt_bigcode_quantization.group_quant, + }, + ), +} diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/model/model_preset.py similarity index 61% rename from python/mlc_chat/compiler/model/model.py rename to python/mlc_chat/model/model_preset.py index e3c737c79d..371f8a0f8b 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/model/model_preset.py @@ -1,128 +1,5 @@ -"""A centralized registry of all existing model architures and their configurations.""" -import dataclasses -from typing import Any, Callable, Dict, Tuple - -from tvm.relax.frontend import nn - -from ..loader import ExternMapping, QuantizeMapping -from ..quantization.quantization import Quantization -from .gpt2 import gpt2_loader, gpt2_model, gpt2_quantization -from .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model, gpt_bigcode_quantization -from .gpt_neox import gpt_neox_loader, gpt_neox_model, gpt_neox_quantization -from .llama import llama_loader, llama_model, llama_quantization -from .mistral import mistral_loader, mistral_model, mistral_quantization - -ModelConfig = Any -"""A ModelConfig is an object that represents a model 
architecture. It is required to have -a class method `from_file` with the following signature: - - def from_file(cls, path: Path) -> ModelConfig: - ... -""" - -FuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping] -FuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]] - - -@dataclasses.dataclass -class Model: - """All about a model architecture: its configuration, its parameter loader and quantization. - - Parameters - ---------- - name : str - The name of the model. - - model : Callable[[ModelConfig], nn.Module] - A method that creates the `nn.Module` that represents the model from `ModelConfig`. - - config : ModelConfig - A class that has a `from_file` class method, whose signature is "Path -> ModelConfig". - - source : Dict[str, FuncGetExternMap] - A dictionary that maps the name of a source format to parameter mapping. - - quantize: Dict[str, FuncQuantization] - A dictionary that maps the name of a quantization method to quantized model and the - quantization parameter mapping. - """ - - name: str - config: ModelConfig - model: Callable[[ModelConfig], nn.Module] - source: Dict[str, FuncGetExternMap] - quantize: Dict[str, FuncQuantization] - - -MODELS: Dict[str, Model] = { - "llama": Model( - name="llama", - model=llama_model.LlamaForCasualLM, - config=llama_model.LlamaConfig, - source={ - "huggingface-torch": llama_loader.huggingface, - "huggingface-safetensor": llama_loader.huggingface, - "awq": llama_loader.awq, - }, - quantize={ - "no-quant": llama_quantization.no_quant, - "group-quant": llama_quantization.group_quant, - "awq": llama_quantization.awq_quant, - }, - ), - "mistral": Model( - name="mistral", - model=mistral_model.MistralForCasualLM, - config=mistral_model.MistralConfig, - source={ - "huggingface-torch": mistral_loader.huggingface, - "huggingface-safetensor": mistral_loader.huggingface, - "awq": mistral_loader.awq, - }, - quantize={ - "group-quant": mistral_quantization.group_quant, - }, - ), - "gpt2": Model( - name="gpt2", - model=gpt2_model.GPT2LMHeadModel, - config=gpt2_model.GPT2Config, - source={ - "huggingface-torch": gpt2_loader.huggingface, - "huggingface-safetensor": gpt2_loader.huggingface, - }, - quantize={ - "no-quant": gpt2_quantization.no_quant, - "group-quant": gpt2_quantization.group_quant, - }, - ), - "gpt_neox": Model( - name="gpt_neox", - model=gpt_neox_model.GPTNeoXForCausalLM, - config=gpt_neox_model.GPTNeoXConfig, - source={ - "huggingface-torch": gpt_neox_loader.huggingface, - "huggingface-safetensor": gpt_neox_loader.huggingface, - }, - quantize={ - "no-quant": gpt_neox_quantization.no_quant, - "group-quant": gpt_neox_quantization.group_quant, - }, - ), - "gpt_bigcode": Model( - name="gpt_bigcode", - model=gpt_bigcode_model.GPTBigCodeForCausalLM, - config=gpt_bigcode_model.GPTBigCodeConfig, - source={ - "huggingface-torch": gpt_bigcode_loader.huggingface, - "huggingface-safetensor": gpt_bigcode_loader.huggingface, - }, - quantize={ - "no-quant": gpt_bigcode_quantization.no_quant, - "group-quant": gpt_bigcode_quantization.group_quant, - }, - ), -} +"""A builtin set of models available in MLC LLM.""" +from typing import Any, Dict MODEL_PRESETS: Dict[str, Any] = { "llama2_7b": { diff --git a/python/mlc_chat/compiler/extern/__init__.py b/python/mlc_chat/operator/__init__.py similarity index 52% rename from python/mlc_chat/compiler/extern/__init__.py rename to python/mlc_chat/operator/__init__.py index 310c77d91a..1fe98917cc 100644 --- a/python/mlc_chat/compiler/extern/__init__.py +++ 
diff --git a/python/mlc_chat/compiler/extern/__init__.py b/python/mlc_chat/operator/__init__.py
similarity index 52%
rename from python/mlc_chat/compiler/extern/__init__.py
rename to python/mlc_chat/operator/__init__.py
index 310c77d91a..1fe98917cc 100644
--- a/python/mlc_chat/compiler/extern/__init__.py
+++ b/python/mlc_chat/operator/__init__.py
@@ -1,2 +1,4 @@
 """Extern module for compiler."""
+from .attention import attention
 from .extern import configure, enable, get_store
+from .position_embedding import llama_rope
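After this hunk, `mlc_chat.operator` re-exports exactly five names; spelled out as imports (nothing here beyond what the lines above export):

    from mlc_chat.operator import attention                      # FlashInfer-aware attention wrapper
    from mlc_chat.operator import configure, enable, get_store   # extern-module plumbing
    from mlc_chat.operator import llama_rope                     # rotary embedding from position_embedding.py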
""" store = get_store() if store.configured: return store.configured = True - if store.flashinfer is not None: + if store.flashinfer: assert store.target.kind.name == "cuda" - store.flashinfer.configure( - rope_scale=rope_scale, - rope_theta=rope_theta, - ) diff --git a/python/mlc_chat/compiler/model/position_embedding_op.py b/python/mlc_chat/operator/position_embedding.py similarity index 100% rename from python/mlc_chat/compiler/model/position_embedding_op.py rename to python/mlc_chat/operator/position_embedding.py diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/quantization/__init__.py similarity index 100% rename from python/mlc_chat/compiler/quantization/__init__.py rename to python/mlc_chat/quantization/__init__.py diff --git a/python/mlc_chat/compiler/quantization/awq_quantization.py b/python/mlc_chat/quantization/awq_quantization.py similarity index 99% rename from python/mlc_chat/compiler/quantization/awq_quantization.py rename to python/mlc_chat/quantization/awq_quantization.py index 58e5cc1583..22144e1483 100644 --- a/python/mlc_chat/compiler/quantization/awq_quantization.py +++ b/python/mlc_chat/quantization/awq_quantization.py @@ -1,5 +1,4 @@ """AWQ Quantization""" - from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional @@ -7,7 +6,8 @@ from tvm.relax.frontend import nn from tvm.runtime import NDArray -from ..loader import QuantizeMapping +from mlc_chat.loader import QuantizeMapping + from .utils import convert_uint_to_float diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/quantization/group_quantization.py similarity index 99% rename from python/mlc_chat/compiler/quantization/group_quantization.py rename to python/mlc_chat/quantization/group_quantization.py index 58d77c5245..5865af87f6 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/quantization/group_quantization.py @@ -9,9 +9,10 @@ from tvm.runtime import NDArray from tvm.target import Target -from ...support import logging -from .. 
diff --git a/python/mlc_chat/compiler/model/position_embedding_op.py b/python/mlc_chat/operator/position_embedding.py
similarity index 100%
rename from python/mlc_chat/compiler/model/position_embedding_op.py
rename to python/mlc_chat/operator/position_embedding.py
diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/quantization/__init__.py
similarity index 100%
rename from python/mlc_chat/compiler/quantization/__init__.py
rename to python/mlc_chat/quantization/__init__.py
diff --git a/python/mlc_chat/compiler/quantization/awq_quantization.py b/python/mlc_chat/quantization/awq_quantization.py
similarity index 99%
rename from python/mlc_chat/compiler/quantization/awq_quantization.py
rename to python/mlc_chat/quantization/awq_quantization.py
index 58e5cc1583..22144e1483 100644
--- a/python/mlc_chat/compiler/quantization/awq_quantization.py
+++ b/python/mlc_chat/quantization/awq_quantization.py
@@ -1,5 +1,4 @@
 """AWQ Quantization"""
-
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional
 
@@ -7,7 +6,8 @@
 from tvm.relax.frontend import nn
 from tvm.runtime import NDArray
 
-from ..loader import QuantizeMapping
+from mlc_chat.loader import QuantizeMapping
+
 from .utils import convert_uint_to_float
 
 
diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/quantization/group_quantization.py
similarity index 99%
rename from python/mlc_chat/compiler/quantization/group_quantization.py
rename to python/mlc_chat/quantization/group_quantization.py
index 58d77c5245..5865af87f6 100644
--- a/python/mlc_chat/compiler/quantization/group_quantization.py
+++ b/python/mlc_chat/quantization/group_quantization.py
@@ -9,9 +9,10 @@
 from tvm.runtime import NDArray
 from tvm.target import Target
 
-from ...support import logging
-from .. import tensor_parallel as tp
-from ..loader import QuantizeMapping
+from mlc_chat.loader import QuantizeMapping
+from mlc_chat.support import logging
+from mlc_chat.support import tensor_parallel as tp
+
 from .utils import convert_uint_to_float
 
 logger = logging.getLogger(__name__)
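For downstream code, the quantization internals no longer route through `mlc_chat.compiler`; the new spellings, matching the updated tests later in this patch, are:

    from mlc_chat.loader import QuantizeMapping
    from mlc_chat.quantization import QUANTIZATION, AWQQuantize
    from mlc_chat.quantization.group_quantization import (
        GroupQuantize,
        GroupQuantizeEmbedding,
        GroupQuantizeLinear,
    )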
""" - from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel - MODELS, - Model, - ) + from mlc_chat.model import MODELS, Model # pylint: disable=import-outside-toplevel if model_type == "auto": with open(config, "r", encoding="utf-8") as config_file: @@ -158,7 +154,7 @@ def detect_model_type(model_type: str, config: Path) -> "Model": return MODELS[model_type] -def detect_quantization(quantization_arg: str, config: Path) -> Quantization: +def detect_quantization(quantization_arg: str, config: Path) -> "Quantization": """Detect the model quantization scheme from the configuration file or `--quantization` argument. If `--quantization` is provided, it will override the value on the configuration file. @@ -173,13 +169,15 @@ def detect_quantization(quantization_arg: str, config: Path) -> Quantization: Returns ------- - quantization : mlc_chat.compiler.Quantization + quantization : mlc_chat.quantization.Quantization The model quantization scheme. """ + from mlc_chat.quantization import ( # pylint: disable=import-outside-toplevel + QUANTIZATION, + ) with open(config, "r", encoding="utf-8") as config_file: cfg = json.load(config_file) - if quantization_arg is not None: quantization = QUANTIZATION[quantization_arg] elif "quantization" in cfg: @@ -189,5 +187,4 @@ def detect_quantization(quantization_arg: str, config: Path) -> Quantization: f"'quantization' not found in: {config}. " f"Please explicitly specify `--quantization` instead." ) - return quantization diff --git a/python/mlc_chat/support/random.py b/python/mlc_chat/support/random.py new file mode 100644 index 0000000000..0568276d12 --- /dev/null +++ b/python/mlc_chat/support/random.py @@ -0,0 +1,16 @@ +"""Utility functions for random number generation.""" +import sys + + +def set_global_random_seed(seed): + """Set global random seed for python, numpy, torch and tvm.""" + if "numpy" in sys.modules: + sys.modules["numpy"].random.seed(seed) + if "torch" in sys.modules: + sys.modules["torch"].manual_seed(seed) + if "random" in sys.modules: + sys.modules["random"].seed(seed) + if "tvm" in sys.modules: + set_seed = sys.modules["tvm"].get_global_func("mlc.random.set_seed") + if set_seed: + set_seed(seed) diff --git a/python/mlc_chat/compiler/tensor_parallel.py b/python/mlc_chat/support/tensor_parallel.py similarity index 100% rename from python/mlc_chat/compiler/tensor_parallel.py rename to python/mlc_chat/support/tensor_parallel.py diff --git a/tests/python/loader/test_awq.py b/tests/python/loader/test_awq.py index 43adb0d214..d945a95db0 100644 --- a/tests/python/loader/test_awq.py +++ b/tests/python/loader/test_awq.py @@ -5,8 +5,9 @@ import pytest import tvm -from mlc_chat.compiler import MODEL_PRESETS, MODELS, QUANTIZATION -from mlc_chat.compiler.loader import HuggingFaceLoader +from mlc_chat.loader import HuggingFaceLoader +from mlc_chat.model import MODEL_PRESETS, MODELS +from mlc_chat.quantization import QUANTIZATION from mlc_chat.support import logging, tqdm logging.enable_logging() diff --git a/tests/python/loader/test_huggingface.py b/tests/python/loader/test_huggingface.py index 57bcd851b9..dfbef55c28 100644 --- a/tests/python/loader/test_huggingface.py +++ b/tests/python/loader/test_huggingface.py @@ -5,8 +5,8 @@ import pytest import tvm -from mlc_chat.compiler import MODELS -from mlc_chat.compiler.loader import HuggingFaceLoader +from mlc_chat.loader import HuggingFaceLoader +from mlc_chat.model import MODELS from mlc_chat.support import logging, tqdm logging.enable_logging() diff --git 
diff --git a/python/mlc_chat/compiler/tensor_parallel.py b/python/mlc_chat/support/tensor_parallel.py
similarity index 100%
rename from python/mlc_chat/compiler/tensor_parallel.py
rename to python/mlc_chat/support/tensor_parallel.py
diff --git a/tests/python/loader/test_awq.py b/tests/python/loader/test_awq.py
index 43adb0d214..d945a95db0 100644
--- a/tests/python/loader/test_awq.py
+++ b/tests/python/loader/test_awq.py
@@ -5,8 +5,9 @@
 import pytest
 import tvm
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS, QUANTIZATION
-from mlc_chat.compiler.loader import HuggingFaceLoader
+from mlc_chat.loader import HuggingFaceLoader
+from mlc_chat.model import MODEL_PRESETS, MODELS
+from mlc_chat.quantization import QUANTIZATION
 from mlc_chat.support import logging, tqdm
 
 logging.enable_logging()
diff --git a/tests/python/loader/test_huggingface.py b/tests/python/loader/test_huggingface.py
index 57bcd851b9..dfbef55c28 100644
--- a/tests/python/loader/test_huggingface.py
+++ b/tests/python/loader/test_huggingface.py
@@ -5,8 +5,8 @@
 import pytest
 import tvm
 
-from mlc_chat.compiler import MODELS
-from mlc_chat.compiler.loader import HuggingFaceLoader
+from mlc_chat.loader import HuggingFaceLoader
+from mlc_chat.model import MODELS
 from mlc_chat.support import logging, tqdm
 
 logging.enable_logging()
diff --git a/tests/python/model/test_gpt2.py b/tests/python/model/test_gpt2.py
index 8ca00d5009..9517ad1c45 100644
--- a/tests/python/model/test_gpt2.py
+++ b/tests/python/model/test_gpt2.py
@@ -1,7 +1,7 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS
+from mlc_chat.model import MODEL_PRESETS, MODELS
 
 
 @pytest.mark.parametrize("model_name", ["gpt2"])
diff --git a/tests/python/model/test_gptNeox.py b/tests/python/model/test_gptNeox.py
index e3ae0ae81a..d4fcfdd142 100644
--- a/tests/python/model/test_gptNeox.py
+++ b/tests/python/model/test_gptNeox.py
@@ -1,7 +1,7 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS
+from mlc_chat.model import MODEL_PRESETS, MODELS
 
 
 @pytest.mark.parametrize("model_name", ["redpajama_3b_v1"])
diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py
index 8bbbd75971..8ea682f7f0 100644
--- a/tests/python/model/test_llama.py
+++ b/tests/python/model/test_llama.py
@@ -1,7 +1,7 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS
+from mlc_chat.model import MODEL_PRESETS, MODELS
 
 
 @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"])
diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py
index 21f1392449..4d4c761fb1 100644
--- a/tests/python/model/test_llama_quantization.py
+++ b/tests/python/model/test_llama_quantization.py
@@ -1,8 +1,9 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS, QUANTIZATION
-from mlc_chat.compiler.quantization.group_quantization import (
+from mlc_chat.model import MODEL_PRESETS, MODELS
+from mlc_chat.quantization import QUANTIZATION
+from mlc_chat.quantization.group_quantization import (
     GroupQuantizeEmbedding,
     GroupQuantizeLinear,
 )
diff --git a/tests/python/model/test_mistral.py b/tests/python/model/test_mistral.py
index cb5b3a3320..631b592979 100644
--- a/tests/python/model/test_mistral.py
+++ b/tests/python/model/test_mistral.py
@@ -1,7 +1,7 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 
-from mlc_chat.compiler import MODEL_PRESETS, MODELS
+from mlc_chat.model import MODEL_PRESETS, MODELS
 
 
 @pytest.mark.parametrize("model_name", ["mistral_7b"])
diff --git a/tests/python/quantization/test_awq_quantization.py b/tests/python/quantization/test_awq_quantization.py
index fbdb680cb0..244271aff7 100644
--- a/tests/python/quantization/test_awq_quantization.py
+++ b/tests/python/quantization/test_awq_quantization.py
@@ -9,9 +9,8 @@
 from tvm import DataType
 from tvm.relax.frontend import nn
 
-from mlc_chat.compiler import QUANTIZATION
-from mlc_chat.compiler.loader import QuantizeMapping
-from mlc_chat.compiler.quantization import AWQQuantize
+from mlc_chat.loader import QuantizeMapping
+from mlc_chat.quantization import QUANTIZATION, AWQQuantize
 
 
 def dequantize_np(
diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py
index 3956c8aaa2..72133ff013 100644
--- a/tests/python/quantization/test_group_quantization.py
+++ b/tests/python/quantization/test_group_quantization.py
@@ -9,8 +9,9 @@
 from tvm import DataType
 from tvm.relax.frontend import nn
 
-from mlc_chat.compiler import QUANTIZATION, QuantizeMapping
-from mlc_chat.compiler.quantization.group_quantization import (
+from mlc_chat.loader import QuantizeMapping
+from mlc_chat.quantization import QUANTIZATION
+from mlc_chat.quantization.group_quantization import (
     GroupQuantize,
     GroupQuantizeEmbedding,
     GroupQuantizeLinear,