Skip to content

Commit

Permalink
Minor enhancements to ChatModule (mlc-ai#1132)
Browse files Browse the repository at this point in the history
Some minor enhancements to `ChatModule`, mainly handle the device parsing solely in `_parse_device_str` instead of handling it both in the member function and the `__init__` function to avoid redundancy; and some type annotation fix.
  • Loading branch information
YuchenJin authored Oct 28, 2023
1 parent 2c492e5 commit 2ec0cc8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 41 deletions.
2 changes: 1 addition & 1 deletion cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ class LLMChat {

/*!
* \brief Reload model, tokenizers and configurations from the specified model path.
* \param executable The module to reload.
* \param reload_lib The module to reload, it can either be a path to the library or a tvm Module.
* \param model_path The path to search for models.
* \param app_config_json The JSON string used to partially override the configuration loaded from
* disk, default to empty string.
Expand Down
82 changes: 42 additions & 40 deletions python/mlc_chat/chat_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,14 +532,16 @@ def _get_lib_module_path(
raise FileNotFoundError(err_msg)


def _convert_chat_config_to_json_str(chat_config: Optional[ChatConfig], conv_template: str) -> str:
def _convert_chat_config_to_json_str(
chat_config: Optional[ChatConfig], conv_template: Optional[str]
) -> str:
"""Convert user's input ChatConfig to a json string, omitting ``None`` fields.
Parameters
----------
chat_config : Optional[ChatConfig]
User's input. A partial ChatConfig for overriding ``mlc-chat-config.json``.
conv_template : str
conv_template : Optional[str]
The ``conv_template`` that will be used after considering potential override.
Returns
Expand Down Expand Up @@ -591,7 +593,7 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio
return json.dumps(asdict(generation_config))


def _parse_device_str(device: str):
def _parse_device_str(device: str) -> (tvm.runtime.Device, str):
"""Parse the input device identifier into device name and id.
Parameters
Expand All @@ -603,11 +605,11 @@ def _parse_device_str(device: str):
Returns
-------
dev : tvm.runtime.Device
The device.
device_name : str
The name of the device.
device_id : int
The id of the device, or 0 if not specified in the input.
"""
device_err_msg = (
f"Invalid device name: {device}. Please enter the device in the form "
Expand All @@ -616,14 +618,32 @@ def _parse_device_str(device: str):
)
device_args = device.split(":")
if len(device_args) == 1:
return device_args[0], 0
device_name, device_id = device_args[0], 0
elif len(device_args) == 2:
return device_args[0], int(device_args[1])
device_name, device_id = device_args[0], int(device_args[1])
elif len(device_args) > 2:
raise ValueError(device_err_msg)

if device_name == "cuda":
device = tvm.cuda(device_id)
elif device_name == "metal":
device = tvm.metal(device_id)
elif device_name == "vulkan":
device = tvm.vulkan(device_id)
elif device_name == "rocm":
device = tvm.rocm(device_id)
elif device_name == "opencl":
device = tvm.opencl(device_id)
elif device_name == "auto":
device, device_name = _detect_local_device(device_id)
logging.info(f"System automatically detected device: {device_name}")
else:
raise ValueError(device_err_msg)

return device, device_name

def _detect_local_device(device_id: int = 0):

def _detect_local_device(device_id: int = 0) -> (tvm.runtime.Device, str):
"""Automatically detect the local device if user does not specify.
Parameters
Expand All @@ -633,8 +653,11 @@ def _detect_local_device(device_id: int = 0):
Returns
------
dev : Device
dev : tvm.runtime.Device
The local device.
device_name : str
The name of the device.
"""
if tvm.metal().exist:
return tvm.metal(device_id), "metal"
Expand Down Expand Up @@ -715,34 +738,13 @@ def __init__(
chat_config: Optional[ChatConfig] = None,
model_lib_path: Optional[str] = None,
):
device_err_msg = (
f"Invalid device name: {device}. Please enter the device in the form "
"'device_name:device_id' or 'device_name', where 'device_name' needs to be "
"one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'."
)

# 0. Retrieve device_name and device_id (if any, default 0) from device arg
device_name, device_id = _parse_device_str(device)

# 1. Get self.device
if device_name == "cuda":
self.device = tvm.cuda(device_id)
elif device_name == "metal":
self.device = tvm.metal(device_id)
elif device_name == "vulkan":
self.device = tvm.vulkan(device_id)
elif device_name == "rocm":
self.device = tvm.rocm(device_id)
elif device_name == "opencl":
self.device = tvm.opencl(device_id)
elif device_name == "auto":
self.device, device_name = _detect_local_device(device_id)
logging.info(f"System automatically detected device: {device_name}")
else:
raise ValueError(device_err_msg)
# 0. Get device:
# Retrieve device_name and device_id (if any, default 0) from device arg
self.device, device_name = _parse_device_str(device)
device_type = self.device.device_type
device_id = self.device.device_id

# 2. Populate chat module and their functions
# 1. Populate chat module and their functions
fcreate_chat_mod = tvm.get_global_func("mlc.llm_chat_create")
assert fcreate_chat_mod is not None
chat_mod = fcreate_chat_mod(device_type, device_id)
Expand All @@ -768,13 +770,13 @@ def __init__(
self._get_role0_func = chat_mod["get_role0"]
self._get_role1_func = chat_mod["get_role1"]

# 3. Look up model_path
# 2. Look up model_path
self.model_path, self.config_file_path = _get_model_path(model)

# 4. Instantiate chat_config
# 3. Instantiate chat_config
self.chat_config = _get_chat_config(self.config_file_path, chat_config)

# 5. Look up model library
# 4. Look up model library
self.model_lib_path = _get_lib_module_path(
model,
self.model_path,
Expand All @@ -784,7 +786,7 @@ def __init__(
self.config_file_path,
)

# 6. Call reload
# 5. Call reload
user_chat_config_json_str = _convert_chat_config_to_json_str(
self.chat_config, self.chat_config.conv_template
)
Expand Down

0 comments on commit 2ec0cc8

Please sign in to comment.