From f5aa4232fabbe8e37191cb217c6cf38ab001ea23 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 20 Mar 2025 16:28:54 -0600
Subject: [PATCH] Added ability to quantize with torchao

---
 requirements.txt                  |  4 +--
 toolkit/config_modules.py         |  2 ++
 toolkit/models/base_model.py      |  7 +----
 toolkit/models/cogview4.py        |  6 ++---
 toolkit/models/wan21/wan21.py     |  6 ++---
 toolkit/stable_diffusion_model.py | 14 +++++-----
 toolkit/util/quantize.py          | 44 +++++++++++++++++++++++++++----
 7 files changed, 57 insertions(+), 26 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index a9831469..b2ab0286 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-torch==2.5.1
-torchvision==0.20.1
+torch==2.6.0
+torchvision==0.21.0
 safetensors
 git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b
 transformers==4.49.0
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index a3bcd1ce..c76e4e1f 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -508,6 +508,8 @@ class ModelConfig:
         # only for flux for now
         self.quantize = kwargs.get("quantize", False)
         self.quantize_te = kwargs.get("quantize_te", self.quantize)
+        self.qtype = kwargs.get("qtype", "qfloat8")
+        self.qtype_te = kwargs.get("qtype_te", "qfloat8")
         self.low_vram = kwargs.get("low_vram", False)
         self.attn_masking = kwargs.get("attn_masking", False)
         if self.attn_masking and not self.is_flux:
diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index 48a04ca4..8adcffa8 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -23,27 +23,22 @@ from toolkit.models.decorator import Decorator
 from toolkit.paths import KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
-from toolkit.saving import save_ldm_model_from_diffusers
 from toolkit.sd_device_states_presets import empty_preset
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
-    LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-    FluxTransformer2DModel
-from toolkit.models.lumina2 import Lumina2Transformer2DModel
+    LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel
 import diffusers
 from diffusers import \
     AutoencoderKL, \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline
-from transformers import T5EncoderModel, UMT5EncoderModel
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
-from transformers import Gemma2Model, Qwen2Model, LlamaModel
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork
diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py
index 593fa977..bdb692e0 100644
--- a/toolkit/models/cogview4.py
+++ b/toolkit/models/cogview4.py
@@ -19,7 +19,7 @@ import torch
 import diffusers
 from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
 from optimum.quanto import freeze, qfloat8, QTensor, qint4
-from toolkit.util.quantize import quantize
+from toolkit.util.quantize import quantize, get_qtype
 from transformers import GlmModel, AutoTokenizer
 from diffusers import FlowMatchEulerDiscreteScheduler
 from typing import TYPE_CHECKING
@@ -114,7 +114,7 @@ class CogView4(BaseModel):
         if self.model_config.quantize_te:
             self.print_and_status_update("Quantizing GlmModel")
-            quantize(text_encoder, weights=qfloat8)
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder)
             flush()
@@ -166,7 +166,7 @@ class CogView4(BaseModel):
             # patch the state dict method
             patch_dequantization_on_save(transformer)
-            quantization_type = qfloat8
+            quantization_type = get_qtype(self.model_config.qtype)
             self.print_and_status_update("Quantizing transformer")
             quantize(transformer, weights=quantization_type, **quantization_args)
             freeze(transformer)
diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py
index e6c52f10..a8bb7446 100644
--- a/toolkit/models/wan21/wan21.py
+++ b/toolkit/models/wan21/wan21.py
@@ -29,7 +29,7 @@ import copy
 from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
 import torch
 from optimum.quanto import freeze, qfloat8, QTensor, qint4
-from toolkit.util.quantize import quantize
+from toolkit.util.quantize import quantize, get_qtype
 from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
 from typing import TYPE_CHECKING, List
 from toolkit.accelerator import unwrap_model
@@ -377,7 +377,7 @@ class Wan21(BaseModel):
             quantization_args['exclude'] = []
             # patch the state dict method
             patch_dequantization_on_save(transformer)
-            quantization_type = qfloat8
+            quantization_type = get_qtype(self.model_config.qtype)
             self.print_and_status_update("Quantizing transformer")
             if self.model_config.low_vram:
                 print("Quantizing blocks")
@@ -425,7 +425,7 @@ class Wan21(BaseModel):
         if self.model_config.quantize_te:
             self.print_and_status_update("Quantizing UMT5EncoderModel")
-            quantize(text_encoder, weights=qfloat8)
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder)
             flush()
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 5f459e56..4a63bc68 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -66,7 +66,7 @@ from huggingface_hub import hf_hub_download
 from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
 from optimum.quanto import freeze, qfloat8, QTensor, qint4
-from toolkit.util.quantize import quantize
+from toolkit.util.quantize import quantize, get_qtype
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
@@ -368,7 +368,7 @@ class StableDiffusion:
                 raise ValueError("LoRA is not supported for SD3 models currently")
             if self.model_config.quantize:
-                quantization_type = qfloat8
+                quantization_type = get_qtype(self.model_config.qtype)
                 print_acc("Quantizing transformer")
                 quantize(transformer, weights=quantization_type)
                 freeze(transformer)
@@ -394,7 +394,7 @@ class StableDiffusion:
             if self.model_config.quantize:
                 print_acc("Quantizing T5")
-                quantize(text_encoder_3, weights=qfloat8)
+                quantize(text_encoder_3, weights=get_qtype(self.model_config.qtype))
                 freeze(text_encoder_3)
                 flush()
@@ -739,7 +739,7 @@ class StableDiffusion:
             if self.model_config.quantize:
                 # patch the state dict method
                 patch_dequantization_on_save(transformer)
-                quantization_type = qfloat8
+                quantization_type = get_qtype(self.model_config.qtype)
                 self.print_and_status_update("Quantizing transformer")
                 quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                 freeze(transformer)
@@ -772,7 +772,7 @@ class StableDiffusion:
                 self.print_and_status_update("Quantizing LLM")
             else:
                 self.print_and_status_update("Quantizing T5")
-            quantize(text_encoder_2, weights=qfloat8)
+            quantize(text_encoder_2, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder_2)
             flush()
@@ -853,7 +853,7 @@ class StableDiffusion:
             if self.model_config.quantize:
                 # patch the state dict method
                 patch_dequantization_on_save(transformer)
-                quantization_type = qfloat8
+                quantization_type = get_qtype(self.model_config.qtype)
                 self.print_and_status_update("Quantizing transformer")
                 quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                 freeze(transformer)
@@ -882,7 +882,7 @@ class StableDiffusion:
         if self.model_config.quantize_te:
             self.print_and_status_update("Quantizing Gemma2")
-            quantize(text_encoder, weights=qfloat8)
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder)
             flush()
diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py
index 9d81856b..d19c81cf 100644
--- a/toolkit/util/quantize.py
+++ b/toolkit/util/quantize.py
@@ -1,17 +1,48 @@
 from fnmatch import fnmatch
 from typing import Any, Dict, List, Optional, Union
 import torch
+from dataclasses import dataclass
 
 from optimum.quanto.quantize import _quantize_submodule
-from optimum.quanto.tensor import Optimizer, qtype
+from optimum.quanto.tensor import Optimizer, qtype, qtypes
+from torchao.quantization.quant_api import (
+    quantize_ as torchao_quantize_,
+    Float8WeightOnlyConfig,
+    UIntXWeightOnlyConfig
+)
 
 # the quantize function in quanto had a bug where it was using exclude instead of include
 Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']
 
+torchao_qtypes = {
+    # "int4": Int4WeightOnlyConfig(),
+    "uint2": UIntXWeightOnlyConfig(torch.uint2),
+    "uint3": UIntXWeightOnlyConfig(torch.uint3),
+    "uint4": UIntXWeightOnlyConfig(torch.uint4),
+    "uint5": UIntXWeightOnlyConfig(torch.uint5),
+    "uint6": UIntXWeightOnlyConfig(torch.uint6),
+    "uint7": UIntXWeightOnlyConfig(torch.uint7),
+    "uint8": UIntXWeightOnlyConfig(torch.uint8),
+    "float8": Float8WeightOnlyConfig(),
+}
+
+class aotype:
+    def __init__(self, name: str):
+        self.name = name
+        self.config = torchao_qtypes[name]
+
+def get_qtype(qtype: Union[str, qtype]) -> qtype:
+    if qtype in torchao_qtypes:
+        return aotype(qtype)
+    if isinstance(qtype, str):
+        return qtypes[qtype]
+    else:
+        return qtype
+
 def quantize(
     model: torch.nn.Module,
-    weights: Optional[Union[str, qtype]] = None,
+    weights: Optional[Union[str, qtype, aotype]] = None,
     activations: Optional[Union[str, qtype]] = None,
     optimizer: Optional[Optimizer] = None,
     include: Optional[Union[str, List[str]]] = None,
@@ -57,8 +88,11 @@ def quantize(
             if m.__class__.__name__ in Q_MODULES:
                 continue
             else:
-                _quantize_submodule(model, name, m, weights=weights,
-                                    activations=activations, optimizer=optimizer)
+                if isinstance(weights, aotype):
+                    torchao_quantize_(m, weights.config)
+                else:
+                    _quantize_submodule(model, name, m, weights=weights,
+                                        activations=activations, optimizer=optimizer)
         except Exception as e:
             print(f"Failed to quantize {name}: {e}")
-            raise e
+            raise e
\ No newline at end of file
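
Note (not part of the patch): a minimal sketch of how the new get_qtype() dispatch behaves once this change is applied. Strings registered in torchao_qtypes ("uint2" through "uint8" and "float8") come back wrapped in aotype, so quantize() hands those submodules to torchao's quantize_ with the matching weight-only config; any other string, including the "qfloat8" default from ModelConfig, falls through to optimum.quanto's qtypes lookup and the existing _quantize_submodule path.

    from optimum.quanto.tensor import qtype
    from toolkit.util.quantize import get_qtype, aotype, torchao_qtypes

    # quanto path: "qfloat8" is not a torchao key, so it resolves to qtypes["qfloat8"]
    qf8 = get_qtype("qfloat8")
    assert isinstance(qf8, qtype)

    # torchao path: "uint4" is a torchao key, so an aotype wrapper is returned whose
    # .config is the UIntXWeightOnlyConfig(torch.uint4) instance from torchao_qtypes
    u4 = get_qtype("uint4")
    assert isinstance(u4, aotype)
    assert u4.config is torchao_qtypes["uint4"]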
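
Note (not part of the patch): a hedged sketch of how the new config keys could be combined, assuming ModelConfig can be constructed directly with keyword arguments as the kwargs.get() calls in the config_modules.py hunk suggest. The quantize/quantize_te/qtype/qtype_te names and the "qfloat8" default come from that hunk; the name_or_path value is only an illustrative placeholder, and "uint4" is one of the torchao keys registered in toolkit/util/quantize.py.

    from toolkit.config_modules import ModelConfig

    # Hypothetical: transformer weights on torchao uint4, text encoder on quanto qfloat8.
    model_config = ModelConfig(
        name_or_path="some/base-model",  # placeholder, substitute a real checkpoint
        quantize=True,
        qtype="uint4",        # uint2..uint8 or float8 -> torchao backend
        quantize_te=True,
        qtype_te="qfloat8",   # default -> optimum.quanto backend
    )
    assert model_config.qtype == "uint4"
    assert model_config.qtype_te == "qfloat8"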