forked from mlc-ai/mlc-llm
[Slim-LM] Enable Group Quant (mlc-ai#1129)
* Enable group quant via new interface.
* Minor fix.
* Linting.
* Fix isort.
* Fix mypy.
* TE compute working.
* Skip embed.
* Support cpu+gpu quantization.
* Add target option to tests.
* Linting.
Showing 7 changed files with 401 additions and 8 deletions.
New file (+101 lines):
""" | ||
Quantization specs for Llama2 architecture. | ||
TODO: add docstring | ||
""" | ||
from typing import Callable, Dict, List, Optional | ||
|
||
import tvm | ||
from tvm.runtime import NDArray | ||
|
||
from ..parameter import QuantizeMapping | ||
from ..quantization import QuantizeConfig | ||
from ..quantization.group_quantizer import te_quantize as te_group_quantize | ||
from .llama_config import LlamaConfig | ||
from .llama_model import LlamaForCasualLM | ||
|
||
|
||


def huggingface_group_quantize(
    model_config: LlamaConfig,
    quantize_config: QuantizeConfig,
    target: Optional[tvm.target.Target] = None,
) -> QuantizeMapping:
    """Returns a parameter mapping that maps a parameter in MLC LLM's model
    definition to its eventual names and values after quantization.

    Parameters
    ----------
    model_config : LlamaConfig
        The configuration of the Llama model.
    quantize_config : QuantizeConfig
        The configuration of the group quantization.
    target : Optional[tvm.target.Target]
        The target device to run the quantization on. By default None, which
        means the quantization runs on CPU.

    Returns
    -------
    quantize_map : QuantizeMapping
        The parameter mapping from a parameter in MLC LLM's model definition to
        its eventual names and values after quantization.
    """

    def group_quantize(
        param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None
    ):
        if target is None or target.kind.name == "llvm":
            target = tvm.target.Target("llvm")
            device = tvm.cpu()
        elif target.kind.name == "cuda":
            device = tvm.cuda()
        else:
            raise ValueError(f"Invalid target device: {target}")
        param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param")
        weight_compute, scale_compute, other_computes = te_group_quantize(  # type: ignore
            param_tensor, config
        )
        s = tvm.te.create_schedule(
            [compute.op for compute in [weight_compute, scale_compute] + other_computes]
        )
        if target.kind.name == "cuda":
            # thread binding for cuda
            for compute in [weight_compute, scale_compute] + other_computes:
                xo, xi = s[compute].split(compute.op.axis[0], 256)
                s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x"))
                s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x"))
        f_quantize = tvm.build(
            s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target
        )
        weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device)
        scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device)
        f_quantize(param.copyto(device), weight, scale)
        return weight, scale

    # Param check
    assert (
        quantize_config.kind == "group_quantize"
    ), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}"
    assert (
        quantize_config.name == "q4f16_1"
    ), "Only the q4f16_1 quantization scheme is supported for now."

    # Fetch model parameters & names
    model = LlamaForCasualLM(model_config)
    _, named_params = model.export_tvm(spec=model.get_default_spec())
    parameter_names = {name for name, _ in named_params}

    # Init mappings
    param_map: Dict[str, List[str]] = {}
    map_func: Dict[str, Callable] = {}

    # Dispatch quantization scheme
    # Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py
    for name in parameter_names:
        if "norm.weight" not in name and "embed" not in name:
            param_map[name] = [f"{name}_quantized", f"{name}_scale"]
            map_func[name] = lambda x: group_quantize(x, quantize_config, target=target)
        else:
            # norm and embedding parameters are kept unquantized
            param_map[name] = [name]
            map_func[name] = lambda x: [x]

    return QuantizeMapping(param_map, map_func)
```
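For context, a sketch of how the returned mapping might be consumed. None of this is part of the diff: `model_config` is assumed to be an already-built `LlamaConfig`, `make_q4f16_1_config` and `hf_weights` are hypothetical stand-ins, and `QuantizeMapping` is assumed to expose its two constructor dicts as `param_map` and `map_func`.

```python
import tvm

# Hypothetical inputs: a Llama config, a q4f16_1 QuantizeConfig, and a dict
# of source weights already loaded as tvm.runtime.NDArray.
quantize_config = make_q4f16_1_config()  # hypothetical factory, not in this diff
mapping = huggingface_group_quantize(
    model_config, quantize_config, target=tvm.target.Target("cuda")
)

quantized_params = {}
for name, new_names in mapping.param_map.items():
    # Quantized parameters yield (weight, scale); skipped ones yield [param].
    outputs = mapping.map_func[name](hf_weights[name])
    for new_name, value in zip(new_names, outputs):
        quantized_params[new_name] = value
```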
Changed file (the quantization subpackage's `__init__`):
"""A subpackage for quantization and dequantization algorithms""" | ||
from .quantization import QUANT | ||
from .quantization import QUANT, QuantizeConfig |
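The re-export lets callers import the config type from the subpackage root rather than reaching into the module; a minimal sketch, assuming the absolute package name `mlc_llm` from the path referenced in the dispatch comment above:

```python
# Package path assumed from the mlc_llm/quantization/__init__.py reference above.
from mlc_llm.quantization import QUANT, QuantizeConfig
```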
New file (+70 lines):
"""A group quantizer for on the fly parameter quantization""" | ||
# pylint: disable=too-few-public-methods | ||
|
||
from typing import List, Tuple | ||
|
||
from tvm import te, tir | ||
|
||
from .quantization import QuantizeConfig | ||
|
||
|
||


def te_quantize(
    weight: te.Tensor, config: QuantizeConfig
) -> Tuple[te.Tensor, te.Tensor, List[te.Tensor]]:
    """Group quantization for a weight tensor, defined in tensor expression."""
    # pylint: disable=too-many-locals
    assert len(weight.shape) == 2
    n, m = weight.shape
    # compute the per-group scale from the max absolute value in each group
    r = te.reduce_axis((0, config.group_size), name="r")
    num_group = tir.ceildiv(m, config.group_size)
    scale_shape = (n, num_group)
    max_abs = te.compute(
        shape=scale_shape,
        fcompute=lambda i, j: te.max(
            tir.if_then_else(
                j * config.group_size + r < weight.shape[1],
                te.abs(weight[i, j * config.group_size + r]),
                tir.const(1e-4, config.weight_dtype),
            ),
            axis=r,
        ),
        name="max_abs_value",
    )
    scale = te.compute(
        scale_shape,
        lambda i, j: max_abs[i, j] / tir.const(config.max_int_value, dtype=config.weight_dtype),
        name="scale",
    )

    # shift into [0, 2 * max_int], round, and clamp each scaled element
    tir_max_int = tir.const(config.max_int_value, config.weight_dtype)
    tir_zero = tir.const(0, config.weight_dtype)
    tir_max_int_2 = tir.const(config.max_int_value * 2, config.weight_dtype)
    scaled_weight = te.compute(
        shape=weight.shape,
        fcompute=lambda i, j: tir.min(
            tir.max(
                tir.round(weight[i, j] / scale[i, j // config.group_size] + tir_max_int),
                tir_zero,
            ),
            tir_max_int_2,
        ).astype(config.storage_dtype),
        name="scaled_weight",
    )
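
    # Illustrative check of the mapping above, assuming q4f16_1's
    # max_int_value of 7 (4-bit): for a group whose max |w| is 1.4,
    # scale = 1.4 / 7 = 0.2, so w = -0.6 maps to round(-3 + 7) = 4 and
    # w = 1.4 maps to round(7 + 7) = 14; stored values span [0, 14].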

    # pack num_elem_per_storage quantized elements into each storage word
    r = te.reduce_axis((0, config.num_elem_per_storage), name="r")
    num_storage = config.num_storage_per_group * num_group
    quantized_weight_shape = (n, num_storage)
    quantized_weight = te.compute(
        shape=quantized_weight_shape,
        fcompute=lambda i, j: tir.sum(
            scaled_weight[i, j * config.num_elem_per_storage + r]
            << (r * config.quantize_dtype_bits),
            axis=r,
            where=j * config.num_elem_per_storage + r < m,
        ),
        name="weight",
    )
    return quantized_weight, scale, [max_abs, scaled_weight]
    # pylint: enable=too-many-locals
```
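To make the shapes and packing concrete, here is a small CPU-only sketch that runs `te_quantize` end to end. The `QuantizeConfig` constructor is not shown in this diff, so a `SimpleNamespace` stands in with only the fields `te_quantize` reads; the q4f16_1-style values (4-bit elements, group size 32, uint32 storage) are assumptions for illustration, as is the import of `te_quantize` from the module above.

```python
from types import SimpleNamespace

import numpy as np
import tvm
from tvm import te

# Stand-in for QuantizeConfig with q4f16_1-style values (assumed).
config = SimpleNamespace(
    group_size=32,            # elements sharing one scale
    weight_dtype="float16",
    max_int_value=7,          # 4-bit: biased values live in [0, 14]
    storage_dtype="uint32",
    num_elem_per_storage=8,   # 32 bits / 4 bits per element
    num_storage_per_group=4,  # 32 elements / 8 per storage word
    quantize_dtype_bits=4,
)

n, m = 64, 128  # 128 columns -> 4 groups -> 16 uint32 words per row
weight = te.placeholder((n, m), dtype=config.weight_dtype, name="weight")
quantized, scale, intermediates = te_quantize(weight, config)

s = te.create_schedule([t.op for t in [quantized, scale] + intermediates])
f = tvm.build(s, [weight, quantized, scale], target="llvm")

w_nd = tvm.nd.array(np.random.randn(n, m).astype("float16"))
q_nd = tvm.nd.empty((n, 16), dtype="uint32")   # (n, num_storage)
sc_nd = tvm.nd.empty((n, 4), dtype="float16")  # one scale per group
f(w_nd, q_nd, sc_nd)
```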