Reorganize folder structure (mlc-ai#1502)
This PR reorganizes the existing `python/` folder structure for better clarity; an import-level before/after sketch follows the list.

- `mlc_chat/model` <- `mlc_chat/compiler/model`
- `mlc_chat/quantization` <- `mlc_chat/compiler/quantization`
- `mlc_chat/loader` <- `mlc_chat/compiler/loader`
- `mlc_chat/operator` <- `mlc_chat/compiler/*_op.py`
- `mlc_chat/compiler_pass` <- `mlc_chat/compiler/compiler_pass.py`
- `mlc_chat/interface` <- `mlc_chat/compiler/{compile/gen_config/convert_weight}.py`
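
In import terms, the move looks like this downstream (a sketch assembled from the `cli/convert_weight.py` diff below; the old `mlc_chat.compiler` namespace disappears entirely):

```python
# Before this commit (the import the old cli/convert_weight.py used):
#   from mlc_chat.compiler import HELP, MODELS, QUANTIZATION, convert_weight
# After it, each symbol comes from its own top-level package:
from mlc_chat.help import HELP
from mlc_chat.interface.convert_weight import convert_weight
from mlc_chat.model import MODELS
from mlc_chat.quantization import QUANTIZATION
```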
junrushao authored Dec 28, 2023
1 parent 09ec207 commit 779b1a5
Showing 77 changed files with 429 additions and 470 deletions.
36 changes: 1 addition & 35 deletions python/mlc_chat/base.py
@@ -10,7 +10,7 @@


 def _load_mlc_llm_lib():
-    """Load mlc llm lib"""
+    """Load MLC LLM lib"""
     if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
         for path in libinfo.get_dll_directories():
             os.add_dll_directory(path)
@@ -24,37 +24,3 @@ def _load_mlc_llm_lib():
 # only load once here
 if os.environ.get("SKIP_LOADING_MLCLLM_SO", "0") == "0":
     _LIB, _LIB_PATH = _load_mlc_llm_lib()
-
-
-def get_delta_message(curr_message: str, new_message: str) -> str:
-    r"""Given the current message and the new message, compute the delta message
-    (the newly generated part, the diff of the new message from the current message).
-
-    Parameters
-    ----------
-    curr_message : str
-        The message generated in the previous round.
-    new_message : str
-        The message generated in the new round.
-
-    Returns
-    -------
-    delta_message : str
-        The diff of the new message from the current message (the newly generated part).
-    """
-    f_get_delta_message = tvm.get_global_func("mlc.get_delta_message")
-    return f_get_delta_message(curr_message, new_message)
-
-
-def set_global_random_seed(seed):
-    """Set global random seed for python, numpy, torch and tvm."""
-    if "numpy" in sys.modules:
-        sys.modules["numpy"].random.seed(seed)
-    if "torch" in sys.modules:
-        sys.modules["torch"].manual_seed(seed)
-    if "random" in sys.modules:
-        sys.modules["random"].seed(seed)
-    if "tvm" in sys.modules:
-        set_seed = sys.modules["tvm"].get_global_func("mlc.random.set_seed")
-        if set_seed:
-            set_seed(seed)
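
The deleted helper (re-homed into `callback.py` below) wraps the `mlc.get_delta_message` global function registered on the C++ side. As a rough illustration of its contract — an assumption, since the native implementation is not part of this diff and may also handle rewritten prefixes — a pure-Python stand-in:

```python
def get_delta_message_py(curr_message: str, new_message: str) -> str:
    """Pure-Python stand-in for `mlc.get_delta_message` (illustration only)."""
    # Length of the longest common prefix of the two messages.
    common = 0
    for old_char, new_char in zip(curr_message, new_message):
        if old_char != new_char:
            break
        common += 1
    # The delta is whatever the new round produced past that prefix.
    return new_message[common:]


assert get_delta_message_py("The answer is 4", "The answer is 42.") == "2."
```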
24 changes: 22 additions & 2 deletions python/mlc_chat/callback.py
@@ -3,7 +3,27 @@
 from queue import Queue
 from typing import Optional
 
-from .base import get_delta_message
+
+def _get_delta_message(curr_message: str, new_message: str) -> str:
+    r"""Given the current message and the new message, compute the delta message
+    (the newly generated part, the diff of the new message from the current message).
+
+    Parameters
+    ----------
+    curr_message : str
+        The message generated in the previous round.
+    new_message : str
+        The message generated in the new round.
+
+    Returns
+    -------
+    delta_message : str
+        The diff of the new message from the current message (the newly generated part).
+    """
+    from tvm._ffi import get_global_func  # pylint: disable=import-outside-toplevel
+
+    f_get_delta_message = get_global_func("mlc.get_delta_message")
+    return f_get_delta_message(curr_message, new_message)
 
 
 class DeltaCallback:
@@ -27,7 +47,7 @@ def __call__(self, message: str = "", stopped: bool = False):
             self.stopped_callback()
             self.curr_message = ""
         else:
-            delta = get_delta_message(self.curr_message, message)
+            delta = _get_delta_message(self.curr_message, message)
             self.curr_message = message
             self.delta_callback(delta)

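For context on how `_get_delta_message` is consumed: `DeltaCallback.__call__` feeds each delta to `delta_callback` and signals the end of generation via `stopped_callback`. A minimal illustrative subclass — the base-class constructor sits outside the visible hunk, so the `super().__init__()` behavior is an assumption:

```python
from mlc_chat.callback import DeltaCallback


class StreamToList(DeltaCallback):
    """Collect streamed deltas into a list (illustration only)."""

    def __init__(self):
        super().__init__()  # assumed to initialize curr_message, per __call__ above
        self.chunks = []

    def delta_callback(self, delta: str):
        # Receives only the newly generated text, per the diff above.
        self.chunks.append(delta)

    def stopped_callback(self):
        # Generation finished; __call__ has already reset curr_message.
        print("".join(self.chunks), flush=True)
```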
2 changes: 1 addition & 1 deletion python/mlc_chat/chat_module.py
@@ -15,7 +15,7 @@
 from mlc_chat.support import logging
 from mlc_chat.support.auto_device import detect_device
 
-from . import base  # pylint: disable=unused-import
+from . import base as _
 
 if TYPE_CHECKING:
     from .interface.openai_api import ChatMessage
18 changes: 10 additions & 8 deletions python/mlc_chat/cli/compile.py
@@ -5,22 +5,24 @@
 from pathlib import Path
 from typing import Union
 
-from mlc_chat.compiler import (  # pylint: disable=redefined-builtin
-    HELP,
-    MODELS,
-    QUANTIZATION,
+from mlc_chat.help import HELP
+from mlc_chat.interface.compile import (  # pylint: disable=redefined-builtin
     ModelConfigOverride,
     OptimizationFlags,
     compile,
 )
-
-from ..support.argparse import ArgumentParser
-from ..support.auto_config import (
+from mlc_chat.model import MODELS
+from mlc_chat.quantization import QUANTIZATION
+from mlc_chat.support.argparse import ArgumentParser
+from mlc_chat.support.auto_config import (
     detect_mlc_chat_config,
     detect_model_type,
     detect_quantization,
 )
-from ..support.auto_target import detect_system_lib_prefix, detect_target_and_host
+from mlc_chat.support.auto_target import (
+    detect_system_lib_prefix,
+    detect_target_and_host,
+)
 
 
 def main(argv):
14 changes: 8 additions & 6 deletions python/mlc_chat/cli/convert_weight.py
@@ -3,12 +3,14 @@
 from pathlib import Path
 from typing import Union
 
-from mlc_chat.compiler import HELP, MODELS, QUANTIZATION, convert_weight
-
-from ..support.argparse import ArgumentParser
-from ..support.auto_config import detect_config, detect_model_type
-from ..support.auto_device import detect_device
-from ..support.auto_weight import detect_weight
+from mlc_chat.help import HELP
+from mlc_chat.interface.convert_weight import convert_weight
+from mlc_chat.model import MODELS
+from mlc_chat.quantization import QUANTIZATION
+from mlc_chat.support.argparse import ArgumentParser
+from mlc_chat.support.auto_config import detect_config, detect_model_type
+from mlc_chat.support.auto_device import detect_device
+from mlc_chat.support.auto_weight import detect_weight
 
 
 def main(argv):
10 changes: 6 additions & 4 deletions python/mlc_chat/cli/gen_config.py
@@ -2,10 +2,12 @@
 from pathlib import Path
 from typing import Union
 
-from mlc_chat.compiler import CONV_TEMPLATES, HELP, MODELS, QUANTIZATION, gen_config
-
-from ..support.argparse import ArgumentParser
-from ..support.auto_config import detect_config, detect_model_type
+from mlc_chat.help import HELP
+from mlc_chat.interface.gen_config import CONV_TEMPLATES, gen_config
+from mlc_chat.model import MODELS
+from mlc_chat.quantization import QUANTIZATION
+from mlc_chat.support.argparse import ArgumentParser
+from mlc_chat.support.auto_config import detect_config, detect_model_type
 
 
 def main(argv):
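All three CLI wrappers keep the same `main(argv)` shape while swapping their imports. A hypothetical smoke test that they still parse arguments — only `--help` is used here, since no other flags are visible in this diff:

```python
# Illustrative only: exercises the reorganized CLI modules' main(argv) entry points.
from mlc_chat.cli import compile as compile_cli  # aliased because `compile` is a builtin
from mlc_chat.cli import convert_weight, gen_config

for cli in (compile_cli, convert_weight, gen_config):
    try:
        cli.main(["--help"])
    except SystemExit:  # argparse exits after printing help
        pass
```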
14 changes: 0 additions & 14 deletions python/mlc_chat/compiler/__init__.py

This file was deleted.

28 changes: 0 additions & 28 deletions python/mlc_chat/compiler/extern/flashinfer.py

This file was deleted.

96 changes: 0 additions & 96 deletions python/mlc_chat/compiler/flags_optimization.py

This file was deleted.

3 changes: 0 additions & 3 deletions python/mlc_chat/compiler/model/__init__.py

This file was deleted.

File renamed without changes.
@@ -169,7 +169,7 @@ def _resolve_tir_var_mapping(  # pylint: disable=too-many-locals
 
     ret_tensors = call.sinfo_args[0]
     ret_tensors = (
-        [ret_tensors]
+        [ret_tensors]  # type: ignore[assignment]
         if isinstance(ret_tensors, relax.TensorStructInfo)
         else list(ret_tensors.fields)
     )
File renamed without changes.
7 changes: 1 addition & 6 deletions python/mlc_chat/compiler/help.py → python/mlc_chat/help.py
@@ -1,6 +1,4 @@
 """Help message for CLI arguments."""
-from .model import MODEL_PRESETS
-
 HELP = {
     "config": (
         """
@@ -17,10 +15,7 @@
 as well as an optional `generation_config.json` provides additional default configuration for
 text generation.
 Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main.
-Pre-defined model architectures include """
-        + ", ".join(f'"{preset}"' for preset in MODEL_PRESETS)
-        + "."
+"""
     ).strip(),
     "quantization": """
 The quantization mode we use to compile. If unprovided, will infer from `model`.
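With `help.py` now import-free, the preset enumeration dropped from the static text can be rebuilt at the call site if still wanted. A hypothetical sketch — both the `"config"` key and `MODEL_PRESETS` being re-exported from `mlc_chat.model` are assumptions, not shown in this diff:

```python
from mlc_chat.help import HELP
from mlc_chat.model import MODEL_PRESETS  # assumption: still exported after the move

# Re-attach the preset list this commit removes from the static help string:
config_help = (
    HELP["config"]
    + " Pre-defined model architectures include "
    + ", ".join(f'"{preset}"' for preset in MODEL_PRESETS)
    + "."
)
```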