mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-26 17:29:27 +00:00
Added support for training on flux schnell. Added example config and instructions for training on flux schnell
This commit is contained in:
@@ -56,7 +56,7 @@ from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||
|
||||
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
|
||||
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -496,10 +496,23 @@ class StableDiffusion:
|
||||
transformer.to(torch.device(self.quantize_device), dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.assistant_lora_path is not None and self.model_config.lora_path:
|
||||
raise ValueError("Cannot load both assistant lora and lora at the same time")
|
||||
if self.model_config.assistant_lora_path is not None:
|
||||
if self.model_config.lora_path:
|
||||
raise ValueError("Cannot load both assistant lora and lora at the same time")
|
||||
|
||||
if not self.is_flux:
|
||||
raise ValueError("Assistant lora is only supported for flux models currently")
|
||||
|
||||
# handle downloading from the hub if needed
|
||||
if not os.path.exists(self.model_config.assistant_lora_path):
|
||||
print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}")
|
||||
new_lora_path = hf_hub_download(
|
||||
self.model_config.assistant_lora_path,
|
||||
filename="pytorch_lora_weights.safetensors"
|
||||
)
|
||||
# replace the path
|
||||
self.model_config.assistant_lora_path = new_lora_path
|
||||
|
||||
if self.model_config.assistant_lora_path is not None and self.is_flux:
|
||||
# for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
|
||||
# quantized weights, so it would have to be processed unmerged (slow). Since schnell samples in just 4 steps,
|
||||
# it is better to merge it in now and sample slowly later; otherwise training speed is cut in half
|
||||
@@ -509,6 +522,10 @@ class StableDiffusion:
|
||||
self.model_config.lora_path = self.model_config.assistant_lora_path
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
print("Fusing in LoRA")
|
||||
# if doing low vram, do this on the gpu, painfully slow otherwise
|
||||
if self.low_vram:
|
||||
print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.")
|
||||
# need the pipe to do this unfortunately for now
|
||||
# we have to fuse in the weights before quantizing
|
||||
pipe: FluxPipeline = FluxPipeline(
|
||||
|
||||
Reference in New Issue
Block a user