Fix issue with picking layers for quantization, adjust layers for better quantization of CogView4

Jaret Burkett
2025-03-05 13:44:40 -07:00
parent aa44828c0c
commit 4fe33f51c1
3 changed files with 78 additions and 4 deletions


@@ -15,7 +15,8 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
 import torch
 import diffusers
 from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
-from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from optimum.quanto import freeze, qfloat8, QTensor, qint4
+from toolkit.util.quantize import quantize
 from transformers import GlmModel, AutoTokenizer
 from diffusers import FlowMatchEulerDiscreteScheduler
 from typing import TYPE_CHECKING
@@ -142,12 +143,29 @@ class CogView4(BaseModel):
         flush()
         if self.model_config.quantize:
             quantization_args = self.model_config.quantize_kwargs
+            if 'exclude' not in quantization_args:
+                quantization_args['exclude'] = []
+            if 'include' not in quantization_args:
+                quantization_args['include'] = []
+            # Be more specific with the include pattern to exactly match transformer blocks
+            quantization_args['include'] += ["transformer_blocks.*"]
+            # Exclude all LayerNorm layers within transformer blocks
+            quantization_args['exclude'] += [
+                "transformer_blocks.*.norm1",
+                "transformer_blocks.*.norm2",
+                "transformer_blocks.*.norm2_context",
+                "transformer_blocks.*.attn1.norm_q",
+                "transformer_blocks.*.attn1.norm_k"
+            ]
             # patch the state dict method
             patch_dequantization_on_save(transformer)
             quantization_type = qfloat8
             self.print_and_status_update("Quantizing transformer")
-            quantize(transformer, weights=quantization_type,
-                     **self.model_config.quantize_kwargs)
+            quantize(transformer, weights=quantization_type, **quantization_args)
             freeze(transformer)
             transformer.to(self.device_torch)
         else:
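
For reference, a minimal sketch (not part of the commit) of how these fnmatch-style include/exclude wildcards resolve against submodule names. The module names below are hypothetical; they only illustrate the rule that a module is quantized when it matches an include pattern and no exclude pattern.

from fnmatch import fnmatch

include = ["transformer_blocks.*"]
exclude = [
    "transformer_blocks.*.norm1",
    "transformer_blocks.*.norm2",
    "transformer_blocks.*.norm2_context",
    "transformer_blocks.*.attn1.norm_q",
    "transformer_blocks.*.attn1.norm_k",
]

def should_quantize(name: str) -> bool:
    # Quantize only modules that match an include pattern and no exclude pattern.
    if not any(fnmatch(name, p) for p in include):
        return False
    return not any(fnmatch(name, p) for p in exclude)

# Hypothetical submodule names, for illustration only:
print(should_quantize("transformer_blocks.0.attn1.to_q"))  # True  -> quantized
print(should_quantize("transformer_blocks.0.norm1"))       # False -> excluded norm layer
print(should_quantize("patch_embed.proj"))                  # False -> outside transformer blocks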


@@ -64,7 +64,8 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
 from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
-from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from optimum.quanto import freeze, qfloat8, QTensor, qint4
+from toolkit.util.quantize import quantize
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc

toolkit/util/quantize.py (new file, 55 lines added)

@@ -0,0 +1,55 @@
+from fnmatch import fnmatch
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from optimum.quanto.quantize import _quantize_submodule
+from optimum.quanto.tensor import Optimizer, qtype
+
+
+# the quantize function in quanto had a bug where it was using exclude instead of include
+def quantize(
+    model: torch.nn.Module,
+    weights: Optional[Union[str, qtype]] = None,
+    activations: Optional[Union[str, qtype]] = None,
+    optimizer: Optional[Optimizer] = None,
+    include: Optional[Union[str, List[str]]] = None,
+    exclude: Optional[Union[str, List[str]]] = None,
+):
+    """Quantize the specified model submodules
+
+    Recursively quantize the submodules of the specified parent model.
+
+    Only modules that have quantized counterparts will be quantized.
+
+    If include patterns are specified, the submodule name must match one of them.
+
+    If exclude patterns are specified, the submodule must not match one of them.
+
+    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
+    https://docs.python.org/3/library/fnmatch.html for more details.
+
+    Note: quantization happens in-place and modifies the original model and its descendants.
+
+    Args:
+        model (`torch.nn.Module`): the model whose submodules will be quantized.
+        weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
+        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
+        include (`Optional[Union[str, List[str]]]`):
+            Patterns constituting the allowlist. If provided, module names must match at
+            least one pattern from the allowlist.
+        exclude (`Optional[Union[str, List[str]]]`):
+            Patterns constituting the denylist. If provided, module names must not match
+            any patterns from the denylist.
+    """
+    if include is not None:
+        include = [include] if isinstance(include, str) else include
+    if exclude is not None:
+        exclude = [exclude] if isinstance(exclude, str) else exclude
+    for name, m in model.named_modules():
+        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
+            continue
+        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
+            continue
+        _quantize_submodule(model, name, m, weights=weights,
+                            activations=activations, optimizer=optimizer)
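
A minimal usage sketch of this patched helper, mirroring the CogView4 change above. The transformer variable and the include/exclude patterns are assumed from that context, not part of the new file.

from optimum.quanto import freeze, qfloat8
from toolkit.util.quantize import quantize

# Quantize only the transformer blocks, keeping their norm layers in full precision.
quantize(
    transformer,  # assumed: a loaded CogView4Transformer2DModel
    weights=qfloat8,
    include=["transformer_blocks.*"],
    exclude=["transformer_blocks.*.norm1", "transformer_blocks.*.norm2"],
)
freeze(transformer)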