mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
For LoRA-assisted training, merge the assistant LoRA in before quantizing, then sample with schnell by applying the assistant LoRA at -1 weight. This almost doubles training speed when using a LoRA adapter.
This commit is contained in:
@@ -184,6 +184,9 @@ class StableDiffusion:
|
|||||||
self.quantize_device = quantize_device if quantize_device is not None else self.device
|
self.quantize_device = quantize_device if quantize_device is not None else self.device
|
||||||
self.low_vram = self.model_config.low_vram
|
self.low_vram = self.model_config.low_vram
|
||||||
|
|
||||||
|
# merge in and preview active with -1 weight
|
||||||
|
self.invert_assistant_lora = False
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.is_loaded:
|
if self.is_loaded:
|
||||||
return
|
return
|
||||||
@@ -493,6 +496,18 @@ class StableDiffusion:
|
|||||||
transformer.to(torch.device(self.quantize_device), dtype=dtype)
|
transformer.to(torch.device(self.quantize_device), dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
|
if self.model_config.assistant_lora_path is not None and self.model_config.lora_path:
|
||||||
|
raise ValueError("Cannot load both assistant lora and lora at the same time")
|
||||||
|
|
||||||
|
if self.model_config.assistant_lora_path is not None and self.is_flux:
|
||||||
|
# for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
|
||||||
|
# quantized weights, so it would have to run unmerged (slow). Since schnell samples in just 4 steps
|
||||||
|
# it is better to merge it in now and sample slowly later; otherwise training speed is cut in half
|
||||||
|
# so we will merge in now and sample with -1 weight later
|
||||||
|
self.invert_assistant_lora = True
|
||||||
|
# trigger it to get merged in
|
||||||
|
self.model_config.lora_path = self.model_config.assistant_lora_path
|
||||||
|
|
||||||
if self.model_config.lora_path is not None:
|
if self.model_config.lora_path is not None:
|
||||||
# need the pipe to do this unfortunately for now
|
# need the pipe to do this unfortunately for now
|
||||||
# we have to fuse in the weights before quantizing
|
# we have to fuse in the weights before quantizing
|
||||||
@@ -677,6 +692,11 @@ class StableDiffusion:
|
|||||||
self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
|
self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
|
||||||
self.model_config.assistant_lora_path, self)
|
self.model_config.assistant_lora_path, self)
|
||||||
|
|
||||||
|
if self.invert_assistant_lora:
|
||||||
|
# invert and disable during training
|
||||||
|
self.assistant_lora.multiplier = -1.0
|
||||||
|
self.assistant_lora.is_active = False
|
||||||
|
|
||||||
if self.is_pixart and self.vae_scale_factor == 16:
|
if self.is_pixart and self.vae_scale_factor == 16:
|
||||||
# TODO make our own pipeline?
|
# TODO make our own pipeline?
|
||||||
# we generate an image 2x larger, so we need to copy the sizes from larger ones down
|
# we generate an image 2x larger, so we need to copy the sizes from larger ones down
|
||||||
@@ -743,12 +763,16 @@ class StableDiffusion:
|
|||||||
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
|
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
|
||||||
):
|
):
|
||||||
merge_multiplier = 1.0
|
merge_multiplier = 1.0
|
||||||
|
flush()
|
||||||
# if using assistant, unfuse it
|
# if using assistant, unfuse it
|
||||||
if self.model_config.assistant_lora_path is not None:
|
if self.model_config.assistant_lora_path is not None:
|
||||||
print("Unloading asistant lora")
|
print("Unloading assistant lora")
|
||||||
# unfortunately, not an easier way with peft
|
if self.invert_assistant_lora:
|
||||||
self.assistant_lora.is_active = False
|
self.assistant_lora.is_active = True
|
||||||
|
# move weights on to the device
|
||||||
|
self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
|
||||||
|
else:
|
||||||
|
self.assistant_lora.is_active = False
|
||||||
|
|
||||||
if self.network is not None:
|
if self.network is not None:
|
||||||
self.network.eval()
|
self.network.eval()
|
||||||
@@ -1257,9 +1281,13 @@ class StableDiffusion:
|
|||||||
|
|
||||||
# refuse loras
|
# refuse loras
|
||||||
if self.model_config.assistant_lora_path is not None:
|
if self.model_config.assistant_lora_path is not None:
|
||||||
print("Loading asistant lora")
|
print("Loading assistant lora")
|
||||||
# unfortunately, not an easier way with peft
|
if self.invert_assistant_lora:
|
||||||
self.assistant_lora.is_active = True
|
self.assistant_lora.is_active = False
|
||||||
|
# move weights off the device
|
||||||
|
self.assistant_lora.force_to('cpu', self.torch_dtype)
|
||||||
|
else:
|
||||||
|
self.assistant_lora.is_active = True
|
||||||
|
|
||||||
def get_latent_noise(
|
def get_latent_noise(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user