mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
Added ability to quantize with torchao
This commit is contained in:
@@ -66,7 +66,7 @@ from huggingface_hub import hf_hub_download
|
||||
from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
|
||||
|
||||
from optimum.quanto import freeze, qfloat8, QTensor, qint4
|
||||
from toolkit.util.quantize import quantize
|
||||
from toolkit.util.quantize import quantize, get_qtype
|
||||
from toolkit.accelerator import get_accelerator, unwrap_model
|
||||
from typing import TYPE_CHECKING
|
||||
from toolkit.print import print_acc
|
||||
@@ -368,7 +368,7 @@ class StableDiffusion:
|
||||
raise ValueError("LoRA is not supported for SD3 models currently")
|
||||
|
||||
if self.model_config.quantize:
|
||||
quantization_type = qfloat8
|
||||
quantization_type = get_qtype(self.model_config.qtype)
|
||||
print_acc("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type)
|
||||
freeze(transformer)
|
||||
@@ -394,7 +394,7 @@ class StableDiffusion:
|
||||
|
||||
if self.model_config.quantize:
|
||||
print_acc("Quantizing T5")
|
||||
quantize(text_encoder_3, weights=qfloat8)
|
||||
quantize(text_encoder_3, weights=get_qtype(self.model_config.qtype))
|
||||
freeze(text_encoder_3)
|
||||
flush()
|
||||
|
||||
@@ -739,7 +739,7 @@ class StableDiffusion:
|
||||
if self.model_config.quantize:
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
quantization_type = get_qtype(self.model_config.qtype)
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
|
||||
freeze(transformer)
|
||||
@@ -772,7 +772,7 @@ class StableDiffusion:
|
||||
self.print_and_status_update("Quantizing LLM")
|
||||
else:
|
||||
self.print_and_status_update("Quantizing T5")
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
quantize(text_encoder_2, weights=get_qtype(self.model_config.qtype))
|
||||
freeze(text_encoder_2)
|
||||
flush()
|
||||
|
||||
@@ -853,7 +853,7 @@ class StableDiffusion:
|
||||
if self.model_config.quantize:
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
quantization_type = get_qtype(self.model_config.qtype)
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
|
||||
freeze(transformer)
|
||||
@@ -882,7 +882,7 @@ class StableDiffusion:
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
self.print_and_status_update("Quantizing Gemma2")
|
||||
quantize(text_encoder, weights=qfloat8)
|
||||
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user