Mirror of https://github.com/ostris/ai-toolkit.git
For LoRA-assisted training, merge the assistant LoRA in before quantizing, then sample with Schnell at -1 weight. This almost doubles training speed when training with a LoRA adapter.
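The speed claim comes from removing the adapter's extra matmuls from every training step: once the assistant LoRA is fused into the transformer weights (before quantization freezes them), each layer's forward pass is a single matmul instead of base-plus-adapter. A minimal, self-contained sketch of that difference with toy torch modules; ToyLoRALinear and fuse are illustrative names, not ai-toolkit code:

import torch
import torch.nn as nn

class ToyLoRALinear(nn.Module):
    """Unfused path: base matmul plus the low-rank adapter on every forward."""
    def __init__(self, base: nn.Linear, rank: int = 4, multiplier: float = 1.0):
        super().__init__()
        self.base = base
        self.multiplier = multiplier
        self.down = nn.Linear(base.in_features, rank, bias=False)   # A
        self.up = nn.Linear(rank, base.out_features, bias=False)    # B

    def forward(self, x):
        # two extra matmuls per wrapped layer compared to the fused case
        return self.base(x) + self.multiplier * self.up(self.down(x))

def fuse(lora: ToyLoRALinear) -> nn.Linear:
    """Fused path: bake B @ A into the base weight, then run a single matmul."""
    fused = nn.Linear(lora.base.in_features, lora.base.out_features,
                      bias=lora.base.bias is not None)
    with torch.no_grad():
        delta = lora.up.weight @ lora.down.weight            # (out, in)
        fused.weight.copy_(lora.base.weight + lora.multiplier * delta)
        if lora.base.bias is not None:
            fused.bias.copy_(lora.base.bias)
    return fused

x = torch.randn(2, 64)
layer = ToyLoRALinear(nn.Linear(64, 64), rank=4)
fused = fuse(layer)
print(torch.allclose(layer(x), fused(x), atol=1e-5))  # True: same output, fewer ops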
@@ -184,6 +184,9 @@ class StableDiffusion:
         self.quantize_device = quantize_device if quantize_device is not None else self.device
         self.low_vram = self.model_config.low_vram
 
+        # merge in and preview active with -1 weight
+        self.invert_assistant_lora = False
+
     def load_model(self):
         if self.is_loaded:
             return
@@ -493,6 +496,18 @@ class StableDiffusion:
                 transformer.to(torch.device(self.quantize_device), dtype=dtype)
                 flush()
 
+            if self.model_config.assistant_lora_path is not None and self.model_config.lora_path:
+                raise ValueError("Cannot load both assistant lora and lora at the same time")
+
+            if self.model_config.assistant_lora_path is not None and self.is_flux:
+                # for flux, we assume it is flux schnell. We cannot merge the assistant lora in and then
+                # unmerge it on quantized weights, so it would have to run unmerged (slow). Since schnell
+                # samples in just 4 steps, it is better to merge it in now and sample slowly later; otherwise
+                # training speed is roughly cut in half. So we merge it in now and sample with -1 weight later.
+                self.invert_assistant_lora = True
+                # trigger it to get merged in
+                self.model_config.lora_path = self.model_config.assistant_lora_path
+
             if self.model_config.lora_path is not None:
                 # need the pipe to do this unfortunately for now
                 # we have to fuse in the weights before quantizing
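A small, self-contained illustration of the "-1 weight" trick described in the comment above, using toy tensors rather than the toolkit's classes: after the delta B @ A has been merged into the weight, re-applying the same adapter with multiplier -1 at sample time cancels it exactly, so previews behave like the un-assisted model even though the fused (and later quantized) weight is never edited again:

import torch

torch.manual_seed(0)
in_f, out_f, rank = 64, 64, 8
W = torch.randn(out_f, in_f)         # base (e.g. schnell) weight
A = torch.randn(rank, in_f) * 0.02   # assistant LoRA down projection
B = torch.randn(out_f, rank) * 0.02  # assistant LoRA up projection
x = torch.randn(4, in_f)

# merge the assistant in before quantizing; training runs on W_fused alone
W_fused = W + B @ A

# sampling: keep W_fused untouched, apply the same adapter again with weight -1
y_sampled = x @ W_fused.T + (-1.0) * (x @ A.T @ B.T)

# identical to running the original, un-assisted weight
y_base = x @ W.T
print(torch.allclose(y_sampled, y_base, atol=1e-5))  # True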
@@ -677,6 +692,11 @@ class StableDiffusion:
             self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
                 self.model_config.assistant_lora_path, self)
 
+            if self.invert_assistant_lora:
+                # invert and disable during training
+                self.assistant_lora.multiplier = -1.0
+                self.assistant_lora.is_active = False
+
         if self.is_pixart and self.vae_scale_factor == 16:
             # TODO make our own pipeline?
             # we generate an image 2x larger, so we need to copy the sizes from larger ones down
@@ -743,12 +763,16 @@ class StableDiffusion:
             pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
     ):
         merge_multiplier = 1.0
         flush()
         # if using assistant, unfuse it
         if self.model_config.assistant_lora_path is not None:
-            print("Unloading asistant lora")
-            # unfortunately, not an easier way with peft
-            self.assistant_lora.is_active = False
+            print("Unloading assistant lora")
+            if self.invert_assistant_lora:
+                self.assistant_lora.is_active = True
+                # move weights on to the device
+                self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
+            else:
+                self.assistant_lora.is_active = False
 
         if self.network is not None:
             self.network.eval()
@@ -1257,9 +1281,13 @@ class StableDiffusion:
 
         # re-fuse loras
         if self.model_config.assistant_lora_path is not None:
-            print("Loading asistant lora")
-            # unfortunately, not an easier way with peft
-            self.assistant_lora.is_active = True
+            print("Loading assistant lora")
+            if self.invert_assistant_lora:
+                self.assistant_lora.is_active = False
+                # move weights off the device
+                self.assistant_lora.force_to('cpu', self.torch_dtype)
+            else:
+                self.assistant_lora.is_active = True
 
     def get_latent_noise(
             self,
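Taken together, the hunks above form a small state machine around the inverted assistant adapter: disabled and parked on the CPU during training (the fused, quantized weights already contain the assistant), then activated at -1 and moved to the GPU only while sampling. A hedged sketch of that lifecycle with a stand-in adapter object; the real class is LoRASpecialNetwork, and only the attribute names (multiplier, is_active, force_to) mirror the diff, everything else is illustrative:

from dataclasses import dataclass

@dataclass
class ToyAssistantAdapter:
    # mirrors the attributes used in the diff; not the real LoRASpecialNetwork
    multiplier: float = 1.0
    is_active: bool = False
    device: str = "cpu"

    def force_to(self, device, dtype=None):
        self.device = str(device)

def before_sampling(adapter, invert_assistant_lora, device="cuda"):
    if invert_assistant_lora:
        # cancel the merged-in assistant: enable it at -1 and move weights onto the GPU
        adapter.is_active = True
        adapter.force_to(device)
    else:
        adapter.is_active = False

def after_sampling(adapter, invert_assistant_lora):
    if invert_assistant_lora:
        # back to training: the fused weights already include the assistant,
        # so the inverted copy is disabled and parked on the CPU to free VRAM
        adapter.is_active = False
        adapter.force_to("cpu")
    else:
        adapter.is_active = True

adapter = ToyAssistantAdapter(multiplier=-1.0)
before_sampling(adapter, invert_assistant_lora=True)
print(adapter.is_active, adapter.device)   # True cuda
after_sampling(adapter, invert_assistant_lora=True)
print(adapter.is_active, adapter.device)   # False cpu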