Skip to content

Commit

Permalink
[Quantization] Switch to optimum-quanto (huggingface#31732)
Browse files Browse the repository at this point in the history
* switch to optimum-quanto rebase squash

* fix import check

* again

* test try-except

* style
  • Loading branch information
SunMarc authored Oct 2, 2024
1 parent b7474f2 commit cac4a48
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 55 deletions.
2 changes: 1 addition & 1 deletion docker/transformers-quantization-latest-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir gguf
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl

# Add quanto for quantization testing
RUN python3 -m pip install --no-cache-dir quanto
RUN python3 -m pip install --no-cache-dir optimum-quanto

# Add eetq for quantization testing
RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git
Expand Down
49 changes: 36 additions & 13 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from packaging import version

from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging
from .utils import (
is_hqq_available,
is_optimum_quanto_available,
is_quanto_available,
is_torchdynamo_compiling,
logging,
)


if is_quanto_available():
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version >= version.parse("0.2.0"):
from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4

if is_hqq_available():
from hqq.core.quantize import Quantizer as HQQQuantizer

Expand Down Expand Up @@ -754,12 +755,20 @@ class QuantoQuantizedCache(QuantizedCache):

def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version < version.parse("0.2.0"):
raise ImportError(
f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
f"Please upgrade quanto with `pip install -U quanto`"

if is_optimum_quanto_available():
from optimum.quanto import MaxOptimizer, qint2, qint4
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
)
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version < version.parse("0.2.0"):
raise ImportError(
f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
f"Since quanto will be deprecated, please install optimum-quanto instead with `pip install -U optimum-quanto`"
)
from quanto import MaxOptimizer, qint2, qint4

if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
Expand All @@ -776,8 +785,22 @@ def __init__(self, cache_config: CacheConfig) -> None:
self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization

def _quantize(self, tensor, axis):
scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
# We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
if is_optimum_quanto_available():
from optimum.quanto import quantize_weight

scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
return qtensor
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
)
from quanto import AffineQuantizer

scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)

return qtensor

def _dequantize(self, qtensor):
Expand Down
7 changes: 4 additions & 3 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
ModelOutput,
is_accelerate_available,
is_hqq_available,
is_optimum_quanto_available,
is_quanto_available,
is_torchdynamo_compiling,
logging,
Expand Down Expand Up @@ -1674,10 +1675,10 @@ def _prepare_cache_for_generation(
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]

if cache_config.backend == "quanto" and not is_quanto_available():
if cache_config.backend == "quanto" and not (is_optimum_quanto_available() or is_quanto_available()):
raise ImportError(
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
"Please install it via with `pip install quanto`"
"You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
"Please install it via with `pip install optimum-quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
Expand Down
13 changes: 11 additions & 2 deletions src/transformers/integrations/quanto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_torch_available
from ..utils import is_optimum_quanto_available, is_quanto_available, is_torch_available, logging


if is_torch_available():
import torch

logger = logging.get_logger(__name__)


def replace_with_quanto_layers(
model,
Expand Down Expand Up @@ -45,7 +47,14 @@ def replace_with_quanto_layers(
should not be passed by the user.
"""
from accelerate import init_empty_weights
from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8

if is_optimum_quanto_available():
from optimum.quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
)
from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8

w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
a_mapping = {None: None, "float8": qfloat8, "int8": qint8}
Expand Down
40 changes: 30 additions & 10 deletions src/transformers/quantizers/quantizer_quanto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_quanto_available, is_torch_available, logging
from ..utils import (
is_accelerate_available,
is_optimum_quanto_available,
is_quanto_available,
is_torch_available,
logging,
)
from ..utils.quantization_config import QuantoConfig


Expand Down Expand Up @@ -57,11 +63,13 @@ def post_init(self):
)

def validate_environment(self, *args, **kwargs):
if not is_quanto_available():
raise ImportError("Loading a quanto quantized model requires quanto library (`pip install quanto`)")
if not (is_optimum_quanto_available() or is_quanto_available()):
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
)
if not is_accelerate_available():
raise ImportError(
"Loading a quanto quantized model requires accelerate library (`pip install accelerate`)"
"Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
)

def update_device_map(self, device_map):
Expand All @@ -81,11 +89,17 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
return torch_dtype

def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
import quanto
if is_optimum_quanto_available():
from optimum.quanto import QModuleMixin
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`"
)
from quanto import QModuleMixin

not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, quanto.QModuleMixin):
if isinstance(module, QModuleMixin):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
Expand All @@ -106,7 +120,13 @@ def check_quantized_param(
"""
Check if a parameter needs to be quantized.
"""
import quanto
if is_optimum_quanto_available():
from optimum.quanto import QModuleMixin
elif is_quanto_available():
logger.warning_once(
"Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instrad `pip install optimum-quanto`"
)
from quanto import QModuleMixin

device_map = kwargs.get("device_map", None)
param_device = kwargs.get("param_device", None)
Expand All @@ -119,7 +139,7 @@ def check_quantized_param(

module, tensor_name = get_module_from_name(model, param_name)
# We only quantize the weights and the bias is not quantized.
if isinstance(module, quanto.QModuleMixin) and "weight" in tensor_name:
if isinstance(module, QModuleMixin) and "weight" in tensor_name:
# if the weights are quantized, don't need to recreate it again with `create_quantized_param`
return not module.frozen
else:
Expand Down Expand Up @@ -162,7 +182,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
return target_dtype
else:
raise ValueError(
"You are using `device_map='auto'` on a quanto quantized model. To automatically compute"
"You are using `device_map='auto'` on an optimum-quanto quantized model. To automatically compute"
" the appropriate device map, you should upgrade your `accelerate` library,"
"`pip install --upgrade accelerate` or install it from source."
)
Expand Down Expand Up @@ -193,7 +213,7 @@ def _process_model_after_weight_loading(self, model):

@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
return False
return True

def is_serializable(self, safe_serialization=None):
return False
6 changes: 3 additions & 3 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
is_nltk_available,
is_onnx_available,
is_optimum_available,
is_optimum_quanto_available,
is_pandas_available,
is_peft_available,
is_phonemizer_available,
Expand All @@ -102,7 +103,6 @@
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_quanto_available,
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,
Expand Down Expand Up @@ -1194,11 +1194,11 @@ def require_auto_awq(test_case):
return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)


def require_quanto(test_case):
def require_optimum_quanto(test_case):
"""
Decorator for optimum-quanto dependency
"""
return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)
return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case)


def require_compressed_tensors(test_case):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@
is_onnx_available,
is_openai_available,
is_optimum_available,
is_optimum_quanto_available,
is_pandas_available,
is_peft_available,
is_phonemizer_available,
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
_is_optimum_quanto_available = False
try:
importlib.metadata.version("optimum_quanto")
_is_optimum_quanto_available = True
except importlib.metadata.PackageNotFoundError:
_is_optimum_quanto_available = False
# For compressed_tensors, only check spec to allow compressed_tensors-nightly package
_compressed_tensors_available = importlib.util.find_spec("compressed_tensors") is not None
_pandas_available = _is_package_available("pandas")
Expand Down Expand Up @@ -963,9 +969,17 @@ def is_auto_awq_available():


def is_quanto_available():
    """
    Report whether the legacy `quanto` package is installed.

    Returns the module-level `_quanto_available` flag unchanged.

    The deprecation notice is emitted (once) only when the legacy package is
    actually present. Callers use this function as a fallback probe, e.g.
    `is_optimum_quanto_available() or is_quanto_available()`, and telling a
    user who has neither package installed that "importing from quanto will be
    deprecated" is misleading.
    """
    if _quanto_available:
        logger.warning_once(
            "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
        )
    return _quanto_available


def is_optimum_quanto_available():
# `importlib.metadata.version` doesn't work with `optimum.quanto`, need to put `optimum_quanto`
return _is_optimum_quanto_available


def is_compressed_tensors_available():
return _compressed_tensors_available

Expand Down
4 changes: 2 additions & 2 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
is_flaky,
require_accelerate,
require_auto_gptq,
require_quanto,
require_optimum_quanto,
require_torch,
require_torch_gpu,
require_torch_multi_accelerator,
Expand Down Expand Up @@ -1941,7 +1941,7 @@ def test_generate_with_static_cache(self):
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)

@require_quanto
@require_optimum_quanto
@pytest.mark.generate
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:
Expand Down
Loading

0 comments on commit cac4a48

Please sign in to comment.