Added support for training on primary gpu with low_vram flag. Updated example script to remove creepy horse sample at that seed

This commit is contained in:
Jaret Burkett
2024-08-11 09:54:30 -06:00
parent fa02e774b0
commit ec1ea7aa0e
4 changed files with 30 additions and 13 deletions

View File

@@ -56,7 +56,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
from optimum.quanto import freeze, qfloat8, quantize, QTensor
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -174,6 +174,7 @@ class StableDiffusion:
self.is_flow_matching = True
self.quantize_device = quantize_device if quantize_device is not None else self.device
self.low_vram = self.model_config.low_vram
def load_model(self):
if self.is_loaded:
@@ -472,7 +473,9 @@ class StableDiffusion:
# low_cpu_mem_usage=False,
# device_map=None
)
transformer.to(torch.device(self.quantize_device), dtype=dtype)
if not self.low_vram:
# for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu
transformer.to(torch.device(self.quantize_device), dtype=dtype)
flush()
if self.model_config.lora_path is not None:
@@ -493,8 +496,9 @@ class StableDiffusion:
pipe.unload_lora_weights()
if self.model_config.quantize:
quantization_type = qfloat8
print("Quantizing transformer")
quantize(transformer, weights=qfloat8)
quantize(transformer, weights=quantization_type)
freeze(transformer)
transformer.to(self.device_torch)
else: