Added ability to quantize with torchao

This commit is contained in:
Jaret Burkett
2025-03-20 16:28:54 -06:00
parent 3a6b24f4c8
commit f5aa4232fa
7 changed files with 57 additions and 26 deletions

View File

@@ -19,7 +19,7 @@ import torch
import diffusers
from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from toolkit.util.quantize import quantize
from toolkit.util.quantize import quantize, get_qtype
from transformers import GlmModel, AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
from typing import TYPE_CHECKING
@@ -114,7 +114,7 @@ class CogView4(BaseModel):
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing GlmModel")
quantize(text_encoder, weights=qfloat8)
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
freeze(text_encoder)
flush()
@@ -166,7 +166,7 @@ class CogView4(BaseModel):
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
quantization_type = get_qtype(self.model_config.qtype)
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type, **quantization_args)
freeze(transformer)