For lora assisted training, merge in before quantizing then sample with schnell at -1 weight. Almost doubles training speed with lora adapter.

This commit is contained in:
Jaret Burkett
2024-08-16 17:28:44 -06:00
parent 165510ace2
commit 452e0e286d

View File

@@ -184,6 +184,9 @@ class StableDiffusion:
self.quantize_device = quantize_device if quantize_device is not None else self.device
self.low_vram = self.model_config.low_vram
# merge in and preview active with -1 weight
self.invert_assistant_lora = False
def load_model(self):
if self.is_loaded:
return
@@ -493,6 +496,18 @@ class StableDiffusion:
transformer.to(torch.device(self.quantize_device), dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None and self.model_config.lora_path:
raise ValueError("Cannot load both assistant lora and lora at the same time")
if self.model_config.assistant_lora_path is not None and self.is_flux:
# for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
# quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps
# it is better to merge it in now, and sample slowly later, otherwise training is slowed in half
# so we will merge in now and sample with -1 weight later
self.invert_assistant_lora = True
# trigger it to get merged in
self.model_config.lora_path = self.model_config.assistant_lora_path
if self.model_config.lora_path is not None:
# need the pipe to do this unfortunately for now
# we have to fuse in the weights before quantizing
@@ -677,6 +692,11 @@ class StableDiffusion:
self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
self.model_config.assistant_lora_path, self)
if self.invert_assistant_lora:
# invert and disable during training
self.assistant_lora.multiplier = -1.0
self.assistant_lora.is_active = False
if self.is_pixart and self.vae_scale_factor == 16:
# TODO make our own pipeline?
# we generate an image 2x larger, so we need to copy the sizes from larger ones down
@@ -743,12 +763,16 @@ class StableDiffusion:
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
):
merge_multiplier = 1.0
flush()
# if using assistant, unfuse it
if self.model_config.assistant_lora_path is not None:
print("Unloading asistant lora")
# unfortunately, not an easier way with peft
self.assistant_lora.is_active = False
print("Unloading assistant lora")
if self.invert_assistant_lora:
self.assistant_lora.is_active = True
# move weights on to the device
self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
else:
self.assistant_lora.is_active = False
if self.network is not None:
self.network.eval()
@@ -1257,9 +1281,13 @@ class StableDiffusion:
# refuse loras
if self.model_config.assistant_lora_path is not None:
print("Loading asistant lora")
# unfortunately, not an easier way with peft
self.assistant_lora.is_active = True
print("Loading assistant lora")
if self.invert_assistant_lora:
self.assistant_lora.is_active = False
# move weights off the device
self.assistant_lora.force_to('cpu', self.torch_dtype)
else:
self.assistant_lora.is_active = True
def get_latent_noise(
self,