mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Added ability to quantize with torchao
This commit is contained in:
@@ -19,7 +19,7 @@ import torch
 import diffusers
 from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
 from optimum.quanto import freeze, qfloat8, QTensor, qint4
-from toolkit.util.quantize import quantize
+from toolkit.util.quantize import quantize, get_qtype
 from transformers import GlmModel, AutoTokenizer
 from diffusers import FlowMatchEulerDiscreteScheduler
 from typing import TYPE_CHECKING
@@ -114,7 +114,7 @@ class CogView4(BaseModel):

         if self.model_config.quantize_te:
             self.print_and_status_update("Quantizing GlmModel")
-            quantize(text_encoder, weights=qfloat8)
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder)
             flush()

||||
@@ -166,7 +166,7 @@ class CogView4(BaseModel):

         # patch the state dict method
         patch_dequantization_on_save(transformer)
-        quantization_type = qfloat8
+        quantization_type = get_qtype(self.model_config.qtype)
         self.print_and_status_update("Quantizing transformer")
         quantize(transformer, weights=quantization_type, **quantization_args)
         freeze(transformer)

Reference in New Issue
Block a user