Added support for training on flux schnell, along with an example config and instructions for training on it.

Jaret Burkett
2024-08-17 06:58:39 -06:00
parent f9179540d2
commit 81899310f8
4 changed files with 144 additions and 9 deletions


@@ -56,7 +56,7 @@ from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
from huggingface_hub import hf_hub_download
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from typing import TYPE_CHECKING
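The newly imported hf_hub_download is what the loader further down relies on to turn a hub repo id into a local file path. A minimal standalone sketch of that call, using an example repo id (the diff itself does not hard-code one):

from huggingface_hub import hf_hub_download

# Downloads the named file from the hub repo (or reuses the cached copy)
# and returns the absolute local path to it.
local_path = hf_hub_download(
    "ostris/FLUX.1-schnell-training-adapter",  # example repo id, an assumption
    filename="pytorch_lora_weights.safetensors",
)
print(local_path)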
@@ -496,10 +496,23 @@ class StableDiffusion:
transformer.to(torch.device(self.quantize_device), dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None and self.model_config.lora_path:
raise ValueError("Cannot load both assistant lora and lora at the same time")
if self.model_config.assistant_lora_path is not None:
    if self.model_config.lora_path:
        raise ValueError("Cannot load both assistant lora and lora at the same time")
    if not self.is_flux:
        raise ValueError("Assistant lora is only supported for flux models currently")
    # handle downloading from the hub if needed
    if not os.path.exists(self.model_config.assistant_lora_path):
        print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}")
        new_lora_path = hf_hub_download(
            self.model_config.assistant_lora_path,
            filename="pytorch_lora_weights.safetensors"
        )
        # replace the path
        self.model_config.assistant_lora_path = new_lora_path
if self.model_config.assistant_lora_path is not None and self.is_flux:
    # for flux, we assume it is flux schnell. We cannot merge the assistant lora in and then unmerge it on
    # quantized weights, so it would have to run unmerged (slow). Since schnell samples in just 4 steps,
    # it is better to merge it in now and sample slowly later; otherwise training speed is roughly cut in half.
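The merge-before-quantize idea described in that comment, as a standalone sketch rather than the toolkit's actual code path: fuse the assistant LoRA into the transformer weights first, and only quantize afterwards, since fused weights cannot be un-fused once quantized. The repo id is only an example, and diffusers' FluxPipeline LoRA helpers (load_lora_weights, fuse_lora, unload_lora_weights) are assumed to be available:

import torch
from diffusers import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize

# Load the schnell base model and the assistant LoRA (example repo id).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("ostris/FLUX.1-schnell-training-adapter")

# Fuse the LoRA into the base weights, then drop the separate LoRA modules.
pipe.fuse_lora()
pipe.unload_lora_weights()

# Quantize the already-fused transformer; merging after this point would not work.
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)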
@@ -509,6 +522,10 @@ class StableDiffusion:
    self.model_config.lora_path = self.model_config.assistant_lora_path
if self.model_config.lora_path is not None:
    print("Fusing in LoRA")
    # if doing low vram, do this on the gpu; it is painfully slow otherwise
    if self.low_vram:
        print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.")
    # we need the pipe to do this for now, unfortunately
    # the weights have to be fused in before quantizing
    pipe: FluxPipeline = FluxPipeline(