diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 51724130..4274ffe4 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -184,6 +184,9 @@ class StableDiffusion:
         self.quantize_device = quantize_device if quantize_device is not None else self.device
         self.low_vram = self.model_config.low_vram
 
+        # when True, the assistant lora is merged into the base weights and previews run it with a -1 weight
+        self.invert_assistant_lora = False
+
     def load_model(self):
         if self.is_loaded:
             return
@@ -493,6 +496,18 @@ class StableDiffusion:
                 transformer.to(torch.device(self.quantize_device), dtype=dtype)
                 flush()
 
+            if self.model_config.assistant_lora_path is not None and self.model_config.lora_path:
+                raise ValueError("Cannot load both assistant lora and lora at the same time")
+
+            if self.model_config.assistant_lora_path is not None and self.is_flux:
+                # For flux we assume it is flux schnell. We cannot merge the assistant lora into the
+                # quantized weights and unmerge it later, so it would have to run unmerged (slow),
+                # roughly halving training speed. Since schnell samples in just 4 steps, it is better
+                # to merge it in now and cancel it at sample time with a -1 weight.
+                self.invert_assistant_lora = True
+                # trigger it to get merged in
+                self.model_config.lora_path = self.model_config.assistant_lora_path
+
             if self.model_config.lora_path is not None:
                 # need the pipe to do this unfortunately for now
                 # we have to fuse in the weights before quantizing
@@ -677,6 +692,11 @@ class StableDiffusion:
             self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
                 self.model_config.assistant_lora_path, self)
 
+            if self.invert_assistant_lora:
+                # inverted: already merged into the base weights, so disable the network during training
+                self.assistant_lora.multiplier = -1.0
+                self.assistant_lora.is_active = False
+
         if self.is_pixart and self.vae_scale_factor == 16:
             # TODO make our own pipeline?
             # we generate an image 2x larger, so we need to copy the sizes from larger ones down
@@ -743,12 +763,16 @@ class StableDiffusion:
             pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
     ):
         merge_multiplier = 1.0
-
+        flush()
         # if using assistant, unfuse it
         if self.model_config.assistant_lora_path is not None:
-            print("Unloading asistant lora")
-            # unfortunately, not an easier way with peft
-            self.assistant_lora.is_active = False
+            print("Unloading assistant lora")
+            if self.invert_assistant_lora:
+                self.assistant_lora.is_active = True
+                # move its weights onto the device
+                self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
+            else:
+                self.assistant_lora.is_active = False
 
         if self.network is not None:
             self.network.eval()
@@ -1257,9 +1281,13 @@ class StableDiffusion:
 
         # refuse loras
         if self.model_config.assistant_lora_path is not None:
-            print("Loading asistant lora")
-            # unfortunately, not an easier way with peft
-            self.assistant_lora.is_active = True
+            print("Loading assistant lora")
+            if self.invert_assistant_lora:
+                self.assistant_lora.is_active = False
+                # move its weights off the device
+                self.assistant_lora.force_to('cpu', self.torch_dtype)
+            else:
+                self.assistant_lora.is_active = True
 
     def get_latent_noise(
             self,
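
The `invert_assistant_lora` path above relies on the fact that a LoRA update is additive: once the adapter has been fused into the base weights, running the same adapter again with a `-1` multiplier cancels it at sample time without touching the (quantized) base tensors. Below is a minimal, self-contained sketch of that arithmetic; the names, shapes, and `forward` helper are made up for illustration and are not the toolkit's `LoRASpecialNetwork` API.

```python
import torch

torch.manual_seed(0)
d, r = 8, 2
base_weight = torch.randn(d, d)      # original base layer weight W
lora_up = torch.randn(d, r) * 0.1    # LoRA up projection B
lora_down = torch.randn(r, d) * 0.1  # LoRA down projection A
delta = lora_up @ lora_down          # low-rank update delta_W = B @ A

# Fuse the adapter into the base weights once (analogous to pointing lora_path at the
# assistant lora so it gets merged before quantization).
merged_weight = base_weight + delta


def forward(weight: torch.Tensor, x: torch.Tensor, multiplier: float = 0.0) -> torch.Tensor:
    # A LoRA-wrapped linear layer computes x @ (W + multiplier * delta_W)^T.
    # multiplier = 0 means the runtime adapter is inactive.
    return x @ (weight + multiplier * delta).T


x = torch.randn(4, d)

# Training: the runtime adapter stays inactive because its effect already lives in merged_weight.
train_out = forward(merged_weight, x, multiplier=0.0)

# Preview sampling: activate the adapter with multiplier -1 to cancel the merged copy and
# recover the behaviour of the un-assisted base weights.
preview_out = forward(merged_weight, x, multiplier=-1.0)
base_out = forward(base_weight, x, multiplier=0.0)

assert torch.allclose(preview_out, base_out, atol=1e-5)
```

With quantized base weights the cancellation is only approximate, since the merged copy is rounded by quantization while the runtime `-1` pass runs in higher precision, but it avoids having to rewrite the quantized tensors, which is what motivates inverting the assistant lora instead of unmerging it.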