Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-03-13 22:49:48 +00:00)
Added support for training on primary gpu with low_vram flag. Updated example script to remove creepy horse sample at that seed
@@ -56,7 +56,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
 
-from optimum.quanto import freeze, qfloat8, quantize, QTensor
+from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 
 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -174,6 +174,7 @@ class StableDiffusion:
         self.is_flow_matching = True
 
         self.quantize_device = quantize_device if quantize_device is not None else self.device
+        self.low_vram = self.model_config.low_vram
 
     def load_model(self):
         if self.is_loaded:
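The new attribute is read straight off model_config, so enabling the behavior only requires the model config to carry a low_vram value. A minimal sketch of that plumbing, assuming a stand-in config object (the ModelConfig dataclass, its defaults, and the class name below are illustrative, not the toolkit's actual definitions):

from dataclasses import dataclass

@dataclass
class ModelConfig:              # illustrative stand-in for the toolkit's model config
    quantize: bool = False
    low_vram: bool = False      # new flag: keep the transformer on the CPU while quantizing

class StableDiffusionSketch:    # hypothetical class, mirrors only the lines touched by the diff
    def __init__(self, model_config: ModelConfig, device: str = "cuda", quantize_device: str | None = None):
        self.model_config = model_config
        self.device = device
        self.quantize_device = quantize_device if quantize_device is not None else self.device
        self.low_vram = self.model_config.low_vram   # the added line

sd = StableDiffusionSketch(ModelConfig(quantize=True, low_vram=True))
print(sd.low_vram)   # True -> quantize on the CPU, keep the GPU free for training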
@@ -472,7 +473,9 @@ class StableDiffusion:
                 # low_cpu_mem_usage=False,
                 # device_map=None
             )
-            transformer.to(torch.device(self.quantize_device), dtype=dtype)
+            if not self.low_vram:
+                # for low vram we leave it on the cpu. Quantizes slower, but allows training on the primary gpu
+                transformer.to(torch.device(self.quantize_device), dtype=dtype)
             flush()
 
             if self.model_config.lora_path is not None:
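The effect of the branch above: without low_vram the transformer is moved to quantize_device before quantizing; with low_vram it stays on the CPU, so quantization is slower but the primary GPU keeps its memory free for training. A self-contained sketch of that pattern, assuming a toy module in place of the real transformer (the module, the flag variable, and the dtype choice are illustrative; the quantize/freeze calls mirror the diff):

import torch
import torch.nn as nn
from optimum.quanto import freeze, qfloat8, quantize

low_vram = True        # hypothetical stand-in for model_config.low_vram
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# toy stand-in for the real transformer
transformer = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

if not low_vram:
    # plenty of VRAM: move to the GPU first so quantization runs faster
    transformer.to(device, dtype=torch.bfloat16)
# with low_vram the model stays on the CPU here; quantizing is slower,
# but peak GPU memory stays low until the weights are already quantized

quantize(transformer, weights=qfloat8)   # swap Linear weights for 8-bit float quantized tensors
freeze(transformer)                      # materialize the quantized weights
transformer.to(device)                   # the now-smaller model moves to the training GPU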
@@ -493,8 +496,9 @@ class StableDiffusion:
                 pipe.unload_lora_weights()
 
             if self.model_config.quantize:
+                quantization_type = qfloat8
                 print("Quantizing transformer")
-                quantize(transformer, weights=qfloat8)
+                quantize(transformer, weights=quantization_type)
                 freeze(transformer)
                 transformer.to(self.device_torch)
             else:
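Pulling the weight type out into quantization_type (together with the newly imported qint4) leaves room to swap precisions without touching the call site. A small sketch of that swap, again with a toy module standing in for the transformer (the use_int4 switch is hypothetical, not a toolkit config key):

import torch.nn as nn
from optimum.quanto import freeze, qfloat8, qint4, quantize

use_int4 = False                                     # hypothetical switch between precisions
quantization_type = qint4 if use_int4 else qfloat8

transformer = nn.Sequential(nn.Linear(2048, 2048), nn.Linear(2048, 2048))  # stand-in module
print("Quantizing transformer")
quantize(transformer, weights=quantization_type)
freeze(transformer)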