Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added ability to quantize with torchao
@@ -1,5 +1,5 @@
-torch==2.5.1
-torchvision==0.20.1
+torch==2.6.0
+torchvision==0.21.0
 safetensors
 git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b
 transformers==4.49.0

@@ -508,6 +508,8 @@ class ModelConfig:
         # only for flux for now
         self.quantize = kwargs.get("quantize", False)
         self.quantize_te = kwargs.get("quantize_te", self.quantize)
+        self.qtype = kwargs.get("qtype", "qfloat8")
+        self.qtype_te = kwargs.get("qtype_te", "qfloat8")
         self.low_vram = kwargs.get("low_vram", False)
         self.attn_masking = kwargs.get("attn_masking", False)
         if self.attn_masking and not self.is_flux:

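The two config keys select the weight format by name. Below is a minimal sketch (illustration only, not part of the commit) of passing them to ModelConfig; "uint4" is assumed to resolve to a torchao weight-only config via get_qtype() in toolkit/util/quantize.py further down, "qfloat8" keeps the existing optimum.quanto path, and which text encoders actually honor qtype_te depends on the individual model implementations.

# Sketch: feeding the new kwargs into ModelConfig.
# Keys other than quantize/qtype/quantize_te/qtype_te are hypothetical placeholders.
from toolkit.config_modules import ModelConfig

model_config = ModelConfig(
    name_or_path="some/base-model",   # hypothetical model path
    quantize=True,
    qtype="uint4",       # assumed torchao UIntXWeightOnlyConfig(torch.uint4) route
    quantize_te=True,
    qtype_te="qfloat8",  # quanto route, same default as before
)
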
@@ -23,27 +23,22 @@ from toolkit.models.decorator import Decorator
 from toolkit.paths import KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.saving import save_ldm_model_from_diffusers
 from toolkit.sd_device_states_presets import empty_preset
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
-    LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-    FluxTransformer2DModel
-from toolkit.models.lumina2 import Lumina2Transformer2DModel
+    LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel
 import diffusers
 from diffusers import \
     AutoencoderKL, \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline
 from transformers import T5EncoderModel, UMT5EncoderModel
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
 from transformers import Gemma2Model, Qwen2Model, LlamaModel
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork

@@ -19,7 +19,7 @@ import torch
 import diffusers
 from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
 from optimum.quanto import freeze, qfloat8, QTensor, qint4
-from toolkit.util.quantize import quantize
+from toolkit.util.quantize import quantize, get_qtype
 from transformers import GlmModel, AutoTokenizer
 from diffusers import FlowMatchEulerDiscreteScheduler
 from typing import TYPE_CHECKING

@@ -114,7 +114,7 @@ class CogView4(BaseModel):
 
         if self.model_config.quantize_te:
             self.print_and_status_update("Quantizing GlmModel")
-            quantize(text_encoder, weights=qfloat8)
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder)
             flush()
 
@@ -166,7 +166,7 @@ class CogView4(BaseModel):
 
             # patch the state dict method
             patch_dequantization_on_save(transformer)
-            quantization_type = qfloat8
+            quantization_type = get_qtype(self.model_config.qtype)
             self.print_and_status_update("Quantizing transformer")
             quantize(transformer, weights=quantization_type, **quantization_args)
             freeze(transformer)

@@ -29,7 +29,7 @@ import copy
 from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
 import torch
 from optimum.quanto import freeze, qfloat8, QTensor, qint4
-from toolkit.util.quantize import quantize
+from toolkit.util.quantize import quantize, get_qtype
 from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
 from typing import TYPE_CHECKING, List
 from toolkit.accelerator import unwrap_model

@@ -377,7 +377,7 @@ class Wan21(BaseModel):
             quantization_args['exclude'] = []
             # patch the state dict method
             patch_dequantization_on_save(transformer)
-            quantization_type = qfloat8
+            quantization_type = get_qtype(self.model_config.qtype)
             self.print_and_status_update("Quantizing transformer")
             if self.model_config.low_vram:
                 print("Quantizing blocks")

@@ -425,7 +425,7 @@ class Wan21(BaseModel):
 
         if self.model_config.quantize_te:
             self.print_and_status_update("Quantizing UMT5EncoderModel")
-            quantize(text_encoder, weights=qfloat8)
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder)
             flush()
 
@@ -66,7 +66,7 @@ from huggingface_hub import hf_hub_download
 from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
 
 from optimum.quanto import freeze, qfloat8, QTensor, qint4
-from toolkit.util.quantize import quantize
+from toolkit.util.quantize import quantize, get_qtype
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc

@@ -368,7 +368,7 @@ class StableDiffusion:
                 raise ValueError("LoRA is not supported for SD3 models currently")
 
             if self.model_config.quantize:
-                quantization_type = qfloat8
+                quantization_type = get_qtype(self.model_config.qtype)
                 print_acc("Quantizing transformer")
                 quantize(transformer, weights=quantization_type)
                 freeze(transformer)

@@ -394,7 +394,7 @@ class StableDiffusion:
 
             if self.model_config.quantize:
                 print_acc("Quantizing T5")
-                quantize(text_encoder_3, weights=qfloat8)
+                quantize(text_encoder_3, weights=get_qtype(self.model_config.qtype))
                 freeze(text_encoder_3)
                 flush()
 
@@ -739,7 +739,7 @@ class StableDiffusion:
             if self.model_config.quantize:
                 # patch the state dict method
                 patch_dequantization_on_save(transformer)
-                quantization_type = qfloat8
+                quantization_type = get_qtype(self.model_config.qtype)
                 self.print_and_status_update("Quantizing transformer")
                 quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                 freeze(transformer)

@@ -772,7 +772,7 @@ class StableDiffusion:
                     self.print_and_status_update("Quantizing LLM")
                 else:
                     self.print_and_status_update("Quantizing T5")
-                quantize(text_encoder_2, weights=qfloat8)
+                quantize(text_encoder_2, weights=get_qtype(self.model_config.qtype))
                 freeze(text_encoder_2)
                 flush()
 
@@ -853,7 +853,7 @@ class StableDiffusion:
             if self.model_config.quantize:
                 # patch the state dict method
                 patch_dequantization_on_save(transformer)
-                quantization_type = qfloat8
+                quantization_type = get_qtype(self.model_config.qtype)
                 self.print_and_status_update("Quantizing transformer")
                 quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                 freeze(transformer)

@@ -882,7 +882,7 @@ class StableDiffusion:
 
             if self.model_config.quantize_te:
                 self.print_and_status_update("Quantizing Gemma2")
-                quantize(text_encoder, weights=qfloat8)
+                quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
                 freeze(text_encoder)
                 flush()
 
@@ -1,17 +1,48 @@
 from fnmatch import fnmatch
 from typing import Any, Dict, List, Optional, Union
 import torch
 from dataclasses import dataclass
 
 from optimum.quanto.quantize import _quantize_submodule
-from optimum.quanto.tensor import Optimizer, qtype
+from optimum.quanto.tensor import Optimizer, qtype, qtypes
+from torchao.quantization.quant_api import (
+    quantize_ as torchao_quantize_,
+    Float8WeightOnlyConfig,
+    UIntXWeightOnlyConfig
+)
 
 # the quantize function in quanto had a bug where it was using exclude instead of include
 
 Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']
 
+torchao_qtypes = {
+    # "int4": Int4WeightOnlyConfig(),
+    "uint2": UIntXWeightOnlyConfig(torch.uint2),
+    "uint3": UIntXWeightOnlyConfig(torch.uint3),
+    "uint4": UIntXWeightOnlyConfig(torch.uint4),
+    "uint5": UIntXWeightOnlyConfig(torch.uint5),
+    "uint6": UIntXWeightOnlyConfig(torch.uint6),
+    "uint7": UIntXWeightOnlyConfig(torch.uint7),
+    "uint8": UIntXWeightOnlyConfig(torch.uint8),
+    "float8": Float8WeightOnlyConfig(),
+}
+
+class aotype:
+    def __init__(self, name: str):
+        self.name = name
+        self.config = torchao_qtypes[name]
+
+def get_qtype(qtype: Union[str, qtype]) -> qtype:
+    if qtype in torchao_qtypes:
+        return aotype(qtype)
+    if isinstance(qtype, str):
+        return qtypes[qtype]
+    else:
+        return qtype
+
 def quantize(
     model: torch.nn.Module,
-    weights: Optional[Union[str, qtype]] = None,
+    weights: Optional[Union[str, qtype, aotype]] = None,
     activations: Optional[Union[str, qtype]] = None,
     optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,

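get_qtype() is a thin name-based dispatcher: strings found in torchao_qtypes come back wrapped in the new aotype holder, any other string is looked up in quanto's qtypes registry, and an already-constructed quanto qtype passes through unchanged. A rough illustration, assuming the module lands at toolkit.util.quantize as in the imports above:

# Illustration of the three get_qtype() branches (assumes toolkit.util.quantize is importable).
from toolkit.util.quantize import get_qtype, aotype

w = get_qtype("uint4")     # torchao name -> aotype wrapping UIntXWeightOnlyConfig(torch.uint4)
assert isinstance(w, aotype)

q = get_qtype("qfloat8")   # not a torchao name -> optimum.quanto qtypes["qfloat8"]

same = get_qtype(q)        # a quanto qtype object is returned unchanged
assert same is q
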
@@ -57,8 +88,11 @@ def quantize(
             if m.__class__.__name__ in Q_MODULES:
                 continue
             else:
-                _quantize_submodule(model, name, m, weights=weights,
-                                    activations=activations, optimizer=optimizer)
+                if isinstance(weights, aotype):
+                    torchao_quantize_(m, weights.config)
+                else:
+                    _quantize_submodule(model, name, m, weights=weights,
+                                        activations=activations, optimizer=optimizer)
         except Exception as e:
             print(f"Failed to quantize {name}: {e}")
             raise e

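Callers keep the same quantize()/freeze() sequence regardless of backend; only the object returned by get_qtype() decides whether a submodule goes through torchao_quantize_ or quanto's _quantize_submodule. A minimal sketch on a toy module, assuming torchao and optimum-quanto are installed:

# Sketch of the calling pattern used in the model files above, on a toy "transformer".
import torch
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype

toy_transformer = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))

quantization_type = get_qtype("uint8")   # or "qfloat8" / "qint4" for the quanto route
quantize(toy_transformer, weights=quantization_type)
freeze(toy_transformer)  # only affects quanto QModules; effectively a no-op on the torchao route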