Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Fix issue with picking layers for quantization, adjust layers for better quantization of cogview4
@@ -15,7 +15,8 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
 import torch
 import diffusers
 from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
-from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from optimum.quanto import freeze, qfloat8, QTensor, qint4
+from toolkit.util.quantize import quantize
 from transformers import GlmModel, AutoTokenizer
 from diffusers import FlowMatchEulerDiscreteScheduler
 from typing import TYPE_CHECKING
@@ -142,12 +143,29 @@ class CogView4(BaseModel):
         flush()

         if self.model_config.quantize:
+            quantization_args = self.model_config.quantize_kwargs
+            if 'exclude' not in quantization_args:
+                quantization_args['exclude'] = []
+            if 'include' not in quantization_args:
+                quantization_args['include'] = []
+
+            # Be more specific with the include pattern to exactly match transformer blocks
+            quantization_args['include'] += ["transformer_blocks.*"]
+
+            # Exclude all LayerNorm layers within transformer blocks
+            quantization_args['exclude'] += [
+                "transformer_blocks.*.norm1",
+                "transformer_blocks.*.norm2",
+                "transformer_blocks.*.norm2_context",
+                "transformer_blocks.*.attn1.norm_q",
+                "transformer_blocks.*.attn1.norm_k"
+            ]
+
             # patch the state dict method
             patch_dequantization_on_save(transformer)
             quantization_type = qfloat8
             self.print_and_status_update("Quantizing transformer")
-            quantize(transformer, weights=quantization_type,
-                     **self.model_config.quantize_kwargs)
+            quantize(transformer, weights=quantization_type, **quantization_args)
             freeze(transformer)
             transformer.to(self.device_torch)
         else:
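For reference, a minimal sketch of how these fnmatch-style include/exclude patterns resolve against submodule names. The module names below are illustrative placeholders, not taken from the actual CogView4 transformer; the matching logic mirrors the patched quantize in toolkit/util/quantize.py.

# Illustrative only: hypothetical module names, same fnmatch semantics
# as toolkit/util/quantize.py.
from fnmatch import fnmatch

include = ["transformer_blocks.*"]
exclude = [
    "transformer_blocks.*.norm1",
    "transformer_blocks.*.norm2",
    "transformer_blocks.*.norm2_context",
    "transformer_blocks.*.attn1.norm_q",
    "transformer_blocks.*.attn1.norm_k",
]

def would_quantize(name: str) -> bool:
    # A submodule is quantized only if it matches an include pattern
    # and matches none of the exclude patterns.
    if not any(fnmatch(name, p) for p in include):
        return False
    return not any(fnmatch(name, p) for p in exclude)

print(would_quantize("transformer_blocks.0.attn1.to_q"))  # True: linear layer inside a block
print(would_quantize("transformer_blocks.0.norm1"))       # False: excluded LayerNorm
print(would_quantize("patch_embed.proj"))                 # False: outside transformer_blocks

The net effect is that only the transformer block weights are quantized, while the block-internal norm layers (and everything outside the blocks) stay in full precision.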
@@ -64,7 +64,8 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
 from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance

-from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from optimum.quanto import freeze, qfloat8, QTensor, qint4
+from toolkit.util.quantize import quantize
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
55  toolkit/util/quantize.py  Normal file
@@ -0,0 +1,55 @@
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Union
import torch

from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype

# the quantize function in quanto had a bug where it was using exclude instead of include


def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
    """Quantize the specified model submodules

    Recursively quantize the submodules of the specified parent model.

    Only modules that have quantized counterparts will be quantized.

    If include patterns are specified, the submodule name must match one of them.

    If exclude patterns are specified, the submodule must not match one of them.

    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
    https://docs.python.org/3/library/fnmatch.html for more details.

    Note: quantization happens in-place and modifies the original model and its descendants.

    Args:
        model (`torch.nn.Module`): the model whose submodules will be quantized.
        weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
        include (`Optional[Union[str, List[str]]]`):
            Patterns constituting the allowlist. If provided, module names must match at
            least one pattern from the allowlist.
        exclude (`Optional[Union[str, List[str]]]`):
            Patterns constituting the denylist. If provided, module names must not match
            any patterns from the denylist.
    """
    if include is not None:
        include = [include] if isinstance(include, str) else include
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    for name, m in model.named_modules():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        _quantize_submodule(model, name, m, weights=weights,
                            activations=activations, optimizer=optimizer)
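A usage sketch, not part of the commit: the fixed quantize can be driven directly with include/exclude patterns on any nn.Module. The toy model below is purely illustrative; the real call site passes the CogView4 transformer and the patterns assembled from quantize_kwargs.

# Sketch only: toy module, not the CogView4 transformer.
import torch
from optimum.quanto import freeze, qfloat8

from toolkit.util.quantize import quantize

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),   # submodule named "0"
    torch.nn.LayerNorm(64),    # submodule named "1"
    torch.nn.Linear(64, 64),   # submodule named "2"
)

# Quantize weights of everything that matches the allowlist, except the
# LayerNorm, which is filtered out by name via the denylist.
quantize(model, weights=qfloat8, include=["*"], exclude=["1"])
freeze(model)

Because patterns are plain fnmatch wildcards rather than regular expressions, "transformer_blocks.*" in the CogView4 change matches every submodule nested under the transformer blocks, and the exclude list then carves the norm layers back out.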