[Slim-LM] Enable Group Quant (mlc-ai#1129)
* Enable group quant via new interface.

* Minor fix.

* Linting.

* Fix isort.

* Fix mypy.

* TE compute working.

* Skip embed.

* Support cpu+gpu quantization.

* Add target option to tests.

* Linting.
zxybazh authored Oct 29, 2023
1 parent 878ae84 commit c0c3a8d
Showing 7 changed files with 401 additions and 8 deletions.
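
As a quick orientation before the per-file diffs, the sketch below shows one way the new interface could be driven for CPU or GPU quantization. It is illustrative only: model_config and quantize_config are assumed to be obtained elsewhere (a LlamaConfig and a group-quantize QuantizeConfig preset), and only the target-selection behavior is taken from this commit.

# Illustrative sketch (not part of this commit): selecting CPU vs. CUDA quantization
# through the new `target` option of huggingface_group_quantize.
import tvm

from mlc_chat.compiler.model.llama_quantization import huggingface_group_quantize


def build_quantize_map(model_config, quantize_config, use_cuda: bool = False):
    # target=None (or an llvm target) quantizes on CPU; a cuda target quantizes on GPU.
    target = tvm.target.Target("cuda") if use_cuda else None
    return huggingface_group_quantize(model_config, quantize_config, target=target)
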
101 changes: 101 additions & 0 deletions python/mlc_chat/compiler/model/llama_quantization.py
@@ -0,0 +1,101 @@
"""
Quantization specs for Llama2 architecture.
TODO: add docstring
"""
from typing import Callable, Dict, List, Optional

import tvm
from tvm.runtime import NDArray

from ..parameter import QuantizeMapping
from ..quantization import QuantizeConfig
from ..quantization.group_quantizer import te_quantize as te_group_quantize
from .llama_config import LlamaConfig
from .llama_model import LlamaForCasualLM


def huggingface_group_quantize(
model_config: LlamaConfig,
quantize_config: QuantizeConfig,
target: Optional[tvm.target.Target] = None,
) -> QuantizeMapping:
"""Returns a parameter mapping that maps a parameter in MLC LLM's model
definition to its eventual names and values after quantization.
Parameters
----------
model_config : LlamaConfig
The configuration of the Llama model.
quantize_config : GroupQuantizeConfig
The configuration of the group quantization.
target : Optional[tvm.target.Target]
        The target device to run the quantization on. Defaults to None, which
        means the quantization runs on the CPU.
Returns
-------
quantize_map : QuantizeMapping
The parameter mapping from a parameter in MLC LLM's model definition to
its eventual names and values after quantization.
"""

def group_quantize(
param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None
):
if target is None or target.kind.name == "llvm":
target = tvm.target.Target("llvm")
device = tvm.cpu()
elif target.kind.name == "cuda":
device = tvm.cuda()
else:
raise ValueError(f"Invalid target device: {target}")
param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param")
weight_compute, scale_compute, other_computes = te_group_quantize( # type: ignore
param_tensor, config
)
s = tvm.te.create_schedule(
[compute.op for compute in [weight_compute, scale_compute] + other_computes]
)
if target.kind.name == "cuda":
# thread_binding for cuda
for compute in [weight_compute, scale_compute] + other_computes:
xo, xi = s[compute].split(compute.op.axis[0], 256)
s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x"))
s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x"))
f_quantize = tvm.build(
s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target
)
weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device)
scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device)
f_quantize(param.copyto(device), weight, scale)
return weight, scale

# Param check
assert (
quantize_config.kind == "group_quantize"
), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}"
assert (
quantize_config.name == "q4f16_1"
), """Only support q4f16_1 quantization scheme for now."""

# Fetch model parameter & names
model = LlamaForCasualLM(model_config)
_, named_params = model.export_tvm(spec=model.get_default_spec())
parameter_names = {name for name, _ in named_params}

# Init mappings
param_map: Dict[str, List[str]] = {}
map_func: Dict[str, Callable] = {}

# Dispatch quantization scheme
# Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py
for name in parameter_names:
if "norm.weight" not in name and "embed" not in name:
param_map[name] = [f"{name}_quantized", f"{name}_scale"]
map_func[name] = lambda x: group_quantize(x, quantize_config, target=target)
else:
# skip these parameters
param_map[name] = [name]
map_func[name] = lambda x: [x]

return QuantizeMapping(param_map, map_func)
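
For a single parameter, the mapping returned above behaves as sketched here; the parameter names are hypothetical, but the split into a quantized/scale pair versus pass-through mirrors the dispatch loop at the end of this file.

# Illustrative sketch of what the returned QuantizeMapping contains (names are hypothetical).
def describe_mapping(qmap, param):
    # A regular weight is mapped to a packed-weight/scale pair and quantized on the fly:
    name = "model.layers.0.mlp.gate_proj.weight"     # hypothetical parameter name
    assert qmap.param_map[name] == [f"{name}_quantized", f"{name}_scale"]
    weight_q, scale = qmap.map_func[name](param)     # runs the TE group-quantize kernel

    # Norm and embedding weights are passed through unchanged:
    assert qmap.param_map["model.norm.weight"] == ["model.norm.weight"]
    return weight_q, scale
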
39 changes: 34 additions & 5 deletions python/mlc_chat/compiler/parameter/huggingface_loader.py
@@ -5,16 +5,21 @@
import logging
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Optional, Tuple

import numpy as np
from tqdm import tqdm
from tvm.runtime import NDArray
from tvm.runtime.ndarray import array as as_ndarray

from .mapping import ExternMapping
from .mapping import ExternMapping, QuantizeMapping
from .stats import Stats
from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard
from .utils import (
ParamQuantizer,
check_parameter_usage,
load_safetensor_shard,
load_torch_shard,
)

logger = logging.getLogger(__name__)

@@ -38,17 +43,22 @@ class HuggingFaceLoader: # pylint: disable=too-few-public-methods
cached_files : Dict[Path, Dict[str, np.ndarray]]
A cache of the loaded files. The key is the path of the file, and the value is a mapping
from parameter name to the parameter value.
quantize_param_map : Optional[QuantizeMapping]
The quantization mapping from MLC to quantized MLC parameters.
"""

stats: Stats
extern_param_map: ExternMapping
cached_files: Dict[Path, Dict[str, np.ndarray]]
torch_to_path: Dict[str, Path]
extern_param_map: ExternMapping
quantize_param_map: Optional[QuantizeMapping]

def __init__(
self,
path: Path,
extern_param_map: ExternMapping,
quantize_param_map: Optional[QuantizeMapping] = None,
) -> None:
"""Create a parameter loader from HuggingFace PyTorch format.
@@ -66,12 +76,17 @@ def __init__(
extern_param_map : ExternMapping
Maps an MLC parameter to a list of PyTorch/SafeTensor parameters.
quantize_param_map: Optional[QuantizeMapping]
            The quantization mapping from MLC to quantized MLC parameters. Defaults to None,
            which means no quantization.
"""
assert path.is_file()
self.stats = Stats()
self.extern_param_map = extern_param_map
self.cached_files = {}
self.torch_to_path = {}
self.quantize_param_map = quantize_param_map
if path.suffix in (".bin", ".safetensors"):
self._load_file(path)
for name in self.cached_files[path].keys():
@@ -90,7 +105,21 @@ def load(self) -> Iterator[Tuple[str, NDArray]]:
mlc_names = _loading_order(self.extern_param_map, self.torch_to_path)
for mlc_name in tqdm(mlc_names):
param = self._load_mlc_param(mlc_name)
yield mlc_name, param
if self.quantize_param_map:
with self.stats.timer("quant_time_sec"):
quantized_params = ParamQuantizer(self.quantize_param_map).quantize(
mlc_name, param
)
for quantized_name, quantized_param in quantized_params:
logger.info(
' Quantized Parameter: "%s", shape: %s, dtype: %s',
quantized_name,
quantized_param.shape,
quantized_param.dtype,
)
yield quantized_name, quantized_param
else:
yield mlc_name, param
cached_files = list(self.cached_files.keys())
for path in cached_files:
self._unload_file(path)
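
With the change above, quantization happens on the fly while checkpoint shards are streamed. A minimal sketch of hooking the loader up to a quantization mapping follows; the checkpoint path and the ExternMapping are placeholders, only the constructor and load() signatures come from this diff.

# Illustrative sketch: streaming HuggingFace weights through on-the-fly quantization.
from pathlib import Path

from mlc_chat.compiler.parameter.huggingface_loader import HuggingFaceLoader


def load_quantized(checkpoint: Path, param_mapping, quantize_map):
    """checkpoint/param_mapping are placeholders: a .bin or .safetensors shard and the
    model-specific ExternMapping; quantize_map comes from huggingface_group_quantize."""
    loader = HuggingFaceLoader(
        path=checkpoint,
        extern_param_map=param_mapping,
        quantize_param_map=quantize_map,
    )
    for name, value in loader.load():
        # With quantize_param_map set, names arrive as "<name>_quantized" / "<name>_scale",
        # except for the skipped norm/embedding parameters, which keep their original name.
        yield name, value
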
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/parameter/mapping.py
@@ -80,7 +80,7 @@ class QuantizeMapping:
used to convert the quantized parameters into the desired form.
"""

param_map: Dict[str, Callable[[str], List[str]]]
param_map: Dict[str, List[str]]
map_func: Dict[str, Callable[[NDArray], List[NDArray]]]


38 changes: 37 additions & 1 deletion python/mlc_chat/compiler/parameter/utils.py
@@ -1,15 +1,51 @@
"""Common utilities for loading parameters"""
# pylint: disable=too-few-public-methods
import logging
from pathlib import Path
from typing import Iterator, Set, Tuple
from typing import TYPE_CHECKING, Iterator, Set, Tuple

import numpy as np

from .mapping import ExternMapping

if TYPE_CHECKING:
from tvm.runtime import NDArray

from ..parameter import QuantizeMapping

logger = logging.getLogger(__name__)


class ParamQuantizer:
"""A parameter quantizer that quantizes given mlc-llm parameters"""

quantize_map: "QuantizeMapping"

def __init__(self, quantize_map: "QuantizeMapping") -> None:
self.quantize_map = quantize_map

def quantize(self, name: str, param: "NDArray") -> Iterator[Tuple[str, "NDArray"]]:
"""Apply quantization to the given parameters
Parameters
----------
name : str
The name of the parameter
param : NDArray
The parameter to be quantized
Returns
-------
        Iterator[Tuple[str, NDArray]]
            The quantized parameters, each with its name.
"""

assert name in self.quantize_map.param_map
quantized_names = self.quantize_map.param_map[name]
quantized_params = self.quantize_map.map_func[name](param)
return zip(quantized_names, quantized_params)


def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]):
"""Check that all external parameters have been used and are stored in the weights file."""
used_extern_names = set(sum(param_map.param_map.values(), []))
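
ParamQuantizer is a thin wrapper over the two dictionaries of a QuantizeMapping; used directly it looks like the sketch below. The parameter name and NDArray are hypothetical, the class and method signatures come from this diff.

# Illustrative sketch: quantizing one named parameter through a QuantizeMapping.
from mlc_chat.compiler.parameter.utils import ParamQuantizer


def quantize_one(quantize_map, name, param):
    """name/param are placeholders, e.g. an MLC weight name and its loaded NDArray."""
    quantizer = ParamQuantizer(quantize_map)
    # Yields e.g. ("<name>_quantized", packed weight) followed by ("<name>_scale", scales).
    return list(quantizer.quantize(name, param))
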
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/quantization/__init__.py
@@ -1,2 +1,2 @@
"""A subpackage for quantization and dequantization algorithms"""
from .quantization import QUANT
from .quantization import QUANT, QuantizeConfig
70 changes: 70 additions & 0 deletions python/mlc_chat/compiler/quantization/group_quantizer.py
@@ -0,0 +1,70 @@
"""A group quantizer for on the fly parameter quantization"""
# pylint: disable=too-few-public-methods

from typing import List, Tuple

from tvm import te, tir

from .quantization import QuantizeConfig


def te_quantize(
weight: te.Tensor, config: QuantizeConfig
) -> Tuple[te.Tensor, te.Tensor, List[te.Tensor]]:
"""Group quantization for weight tensor, defined in tensor expression."""
# pylint: disable=too-many-locals
assert len(weight.shape) == 2
n, m = weight.shape
# compute scale per group
r = te.reduce_axis((0, config.group_size), name="r")
num_group = tir.ceildiv(m, config.group_size)
scale_shape = (n, num_group)
max_abs = te.compute(
shape=scale_shape,
fcompute=lambda i, j: te.max(
tir.if_then_else(
j * config.group_size + r < weight.shape[1],
te.abs(weight[i, j * config.group_size + r]),
tir.const(1e-4, config.weight_dtype),
),
axis=r,
),
name="max_abs_value",
)
scale = te.compute(
        scale_shape,
lambda i, j: max_abs[i, j] / tir.const(config.max_int_value, dtype=config.weight_dtype),
name="scale",
)

# compute scaled weight
tir_max_int = tir.const(config.max_int_value, config.weight_dtype)
tir_zero = tir.const(0, config.weight_dtype)
tir_max_int_2 = tir.const(config.max_int_value * 2, config.weight_dtype)
scaled_weight = te.compute(
shape=weight.shape,
fcompute=lambda i, j: tir.min(
tir.max(
tir.round(weight[i, j] / scale[i, j // config.group_size] + tir_max_int),
tir_zero,
),
tir_max_int_2,
).astype(config.storage_dtype),
)

# compute quantized weight per storage
r = te.reduce_axis((0, config.num_elem_per_storage), name="r")
num_storage = config.num_storage_per_group * num_group
quantized_weight_shape = (n, num_storage)
quantized_weight = te.compute(
shape=quantized_weight_shape,
fcompute=lambda i, j: tir.sum(
scaled_weight[i, j * config.num_elem_per_storage + r]
<< (r * config.quantize_dtype_bits),
axis=r,
where=j * config.num_elem_per_storage + r < m,
),
name="weight",
)
return quantized_weight, scale, [max_abs, scaled_weight]
# pylint: enable=too-many-locals
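
The TE kernel above is easiest to sanity-check against a plain NumPy reference. The sketch below assumes q4f16_1-style constants (group_size 32, 4-bit weights packed eight per uint32 storage word, max_int 7) and that num_storage_per_group equals group_size divided by num_elem_per_storage; it mirrors the three stages of the kernel: per-group max-abs and scale, rounding into [0, 2 * max_int], and bit-packing.

# NumPy reference for the TE compute above (illustrative sketch, not part of the commit).
import numpy as np


def group_quantize_ref(w: np.ndarray, group_size: int = 32, bits: int = 4,
                       storage_bits: int = 32):
    """Reference group quantization under the assumed q4f16_1-style constants."""
    max_int = (1 << (bits - 1)) - 1                   # 7 for 4-bit quantization
    n, m = w.shape
    num_group = -(-m // group_size)                   # ceildiv, as in the TE kernel

    # Stage 1: per-group max-abs, padding out-of-range slots with 1e-4 like the kernel.
    padded_abs = np.full((n, num_group * group_size), 1e-4, dtype=w.dtype)
    padded_abs[:, :m] = np.abs(w)
    scale = padded_abs.reshape(n, num_group, group_size).max(axis=-1) / max_int

    # Stage 2: shift, round, and clip each element into [0, 2 * max_int].
    group_idx = np.arange(m) // group_size
    q = np.clip(np.round(w / scale[:, group_idx] + max_int), 0, 2 * max_int).astype(np.uint32)

    # Stage 3: pack `elems` values per storage word, lowest bits first (the r * bits shift).
    elems = storage_bits // bits                      # num_elem_per_storage
    num_storage = (group_size // elems) * num_group   # assumed num_storage_per_group * num_group
    packed = np.zeros((n, num_storage * elems), dtype=np.uint32)
    packed[:, :m] = q
    shifts = (np.arange(elems) * bits).astype(np.uint32)
    words = (packed.reshape(n, num_storage, elems) << shifts).sum(axis=-1, dtype=np.uint32)
    return words, scale.astype(w.dtype)

Under those assumptions, the packed weights and scales from the TVM path and from this reference should agree elementwise, up to rounding-mode details at exact ties, which also makes it a convenient check for the CUDA thread-binding path.
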