diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py
index 902886bb..62af5498 100644
--- a/toolkit/models/cogview4.py
+++ b/toolkit/models/cogview4.py
@@ -15,7 +15,8 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
 import torch
 import diffusers
 from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
-from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from optimum.quanto import freeze, qfloat8, QTensor, qint4
+from toolkit.util.quantize import quantize
 from transformers import GlmModel, AutoTokenizer
 from diffusers import FlowMatchEulerDiscreteScheduler
 from typing import TYPE_CHECKING
@@ -142,12 +143,29 @@ class CogView4(BaseModel):
             flush()
 
         if self.model_config.quantize:
+            quantization_args = self.model_config.quantize_kwargs
+            if 'exclude' not in quantization_args:
+                quantization_args['exclude'] = []
+            if 'include' not in quantization_args:
+                quantization_args['include'] = []
+
+            # Be more specific with the include pattern to exactly match transformer blocks
+            quantization_args['include'] += ["transformer_blocks.*"]
+
+            # Exclude all LayerNorm layers within transformer blocks
+            quantization_args['exclude'] += [
+                "transformer_blocks.*.norm1",
+                "transformer_blocks.*.norm2",
+                "transformer_blocks.*.norm2_context",
+                "transformer_blocks.*.attn1.norm_q",
+                "transformer_blocks.*.attn1.norm_k"
+            ]
+
             # patch the state dict method
             patch_dequantization_on_save(transformer)
             quantization_type = qfloat8
             self.print_and_status_update("Quantizing transformer")
-            quantize(transformer, weights=quantization_type,
-                     **self.model_config.quantize_kwargs)
+            quantize(transformer, weights=quantization_type, **quantization_args)
             freeze(transformer)
             transformer.to(self.device_torch)
         else:
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index bf93e84c..65736178 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -64,7 +64,8 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
 
 from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
-from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from optimum.quanto import freeze, qfloat8, QTensor, qint4
+from toolkit.util.quantize import quantize
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py
new file mode 100644
index 00000000..fd7b3178
--- /dev/null
+++ b/toolkit/util/quantize.py
@@ -0,0 +1,55 @@
+from fnmatch import fnmatch
+from typing import Any, Dict, List, Optional, Union
+import torch
+
+from optimum.quanto.quantize import _quantize_submodule
+from optimum.quanto.tensor import Optimizer, qtype
+
+# the quantize function in quanto had a bug where it was using exclude instead of include
+
+
+def quantize(
+    model: torch.nn.Module,
+    weights: Optional[Union[str, qtype]] = None,
+    activations: Optional[Union[str, qtype]] = None,
+    optimizer: Optional[Optimizer] = None,
+    include: Optional[Union[str, List[str]]] = None,
+    exclude: Optional[Union[str, List[str]]] = None,
+):
+    """Quantize the specified model submodules
+
+    Recursively quantize the submodules of the specified parent model.
+
+    Only modules that have quantized counterparts will be quantized.
+
+    If include patterns are specified, the submodule name must match one of them.
+
+    If exclude patterns are specified, the submodule must not match one of them.
+
+    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
+    https://docs.python.org/3/library/fnmatch.html for more details.
+
+    Note: quantization happens in-place and modifies the original model and its descendants.
+
+    Args:
+        model (`torch.nn.Module`): the model whose submodules will be quantized.
+        weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
+        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
+        include (`Optional[Union[str, List[str]]]`):
+            Patterns constituting the allowlist. If provided, module names must match at
+            least one pattern from the allowlist.
+        exclude (`Optional[Union[str, List[str]]]`):
+            Patterns constituting the denylist. If provided, module names must not match
+            any patterns from the denylist.
+    """
+    if include is not None:
+        include = [include] if isinstance(include, str) else include
+    if exclude is not None:
+        exclude = [exclude] if isinstance(exclude, str) else exclude
+    for name, m in model.named_modules():
+        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
+            continue
+        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
+            continue
+        _quantize_submodule(model, name, m, weights=weights,
+                            activations=activations, optimizer=optimizer)
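Note on the include/exclude semantics: the patched quantize filters module names with Unix shell-style wildcards via fnmatch, so the effect of the CogView4 patterns can be checked without loading quanto or the model at all. A minimal sketch follows; the module names are illustrative stand-ins for entries from transformer.named_modules(), not an exhaustive listing.

from fnmatch import fnmatch

# Illustrative module names modeled on the CogView4 transformer layout;
# the real names come from transformer.named_modules().
names = [
    "transformer_blocks.0.attn1.to_q",    # quantized: matches include, no exclude
    "transformer_blocks.0.norm1",         # skipped: hits an exclude pattern
    "transformer_blocks.0.attn1.norm_q",  # skipped: hits an exclude pattern
    "patch_embed.proj",                   # skipped: matches no include pattern
]

include = ["transformer_blocks.*"]
exclude = [
    "transformer_blocks.*.norm1",
    "transformer_blocks.*.norm2",
    "transformer_blocks.*.norm2_context",
    "transformer_blocks.*.attn1.norm_q",
    "transformer_blocks.*.attn1.norm_k",
]

for name in names:
    eligible = (any(fnmatch(name, p) for p in include)
                and not any(fnmatch(name, p) for p in exclude))
    print(f"{name}: {'quantize' if eligible else 'skip'}")

A submodule is quantized only when it matches the allowlist and misses the denylist, which is why the LayerNorm patterns keep the norms in full precision even though they sit inside the included transformer blocks.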
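And a minimal end-to-end sketch of the new helper itself, assuming optimum-quanto is installed; the Toy module and its layer names are hypothetical, standing in for a real transformer:

import torch
from optimum.quanto import freeze, qfloat8
from toolkit.util.quantize import quantize  # the patched version above

# Hypothetical toy model; only the linear layers under "blocks"
# should end up quantized.
class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [torch.nn.Linear(8, 8) for _ in range(2)]
        )
        self.norm = torch.nn.LayerNorm(8)

model = Toy()
quantize(model, weights=qfloat8, include=["blocks.*"], exclude=["*norm*"])
freeze(model)
# blocks.0 and blocks.1 are now quantized linear layers;
# model.norm stays a plain LayerNorm (no include match, and denylisted).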