Review note: line breaks and indentation in this patch were reconstructed from a
whitespace-collapsed copy; confirm leading whitespace against the target files before applying.
The resolved LoRA path is now written back from `load_lora_path` instead of `new_lora_path`
(which is only bound in the hub-download branch), so the post-image blob hash on the
stable_diffusion_model.py `index` line is stale. The last hunk's trailing blank context
line was dropped and its counts adjusted accordingly.

diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 3060bb5a..bdb36fa0 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -410,6 +410,7 @@ class ModelConfig:
         self.lora_path = kwargs.get('lora_path', None)
         # mainly for decompression loras for distilled models
         self.assistant_lora_path = kwargs.get('assistant_lora_path', None)
+        self.inference_lora_path = kwargs.get('inference_lora_path', None)
         self.latent_space_version = kwargs.get('latent_space_version', None)

         # only for SDXL models for now
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 6dba69b3..541822bd 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -497,33 +497,47 @@ class StableDiffusion:
                 transformer.to(torch.device(self.quantize_device), dtype=dtype)
             flush()

-            if self.model_config.assistant_lora_path is not None:
+            if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
+                if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None:
+                    raise ValueError("Cannot load both assistant lora and inference lora at the same time")
+
                 if self.model_config.lora_path:
                     raise ValueError("Cannot load both assistant lora and lora at the same time")

                 if not self.is_flux:
-                    raise ValueError("Assistant lora is only supported for flux models currently")
+                    raise ValueError("Assistant/ inference lora is only supported for flux models currently")
+
+                load_lora_path = self.model_config.inference_lora_path
+                if load_lora_path is None:
+                    load_lora_path = self.model_config.assistant_lora_path

-                if os.path.isdir(self.model_config.assistant_lora_path):
-                    self.model_config.assistant_lora_path = os.path.join(
-                        self.model_config.assistant_lora_path, "pytorch_lora_weights.safetensors"
+                if os.path.isdir(load_lora_path):
+                    load_lora_path = os.path.join(
+                        load_lora_path, "pytorch_lora_weights.safetensors"
                     )
-                elif not os.path.exists(self.model_config.assistant_lora_path):
-                    print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}")
+                elif not os.path.exists(load_lora_path):
+                    print(f"Grabbing lora from the hub: {load_lora_path}")
                     new_lora_path = hf_hub_download(
-                        self.model_config.assistant_lora_path,
+                        load_lora_path,
                         filename="pytorch_lora_weights.safetensors"
                     )
                     # replace the path
-                    self.model_config.assistant_lora_path = new_lora_path
+                    load_lora_path = new_lora_path
+
+                # store the resolved path (dir-joined, hub-downloaded, or as given); new_lora_path is only bound on the hub branch
+                if self.model_config.inference_lora_path is not None:
+                    self.model_config.inference_lora_path = load_lora_path
+                if self.model_config.assistant_lora_path is not None:
+                    self.model_config.assistant_lora_path = load_lora_path

-                # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
-                # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps
-                # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half
-                # so we will merge in now and sample with -1 weight later
-                self.invert_assistant_lora = True
-                # trigger it to get merged in
-                self.model_config.lora_path = self.model_config.assistant_lora_path
+                if self.model_config.assistant_lora_path is not None:
+                    # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
+                    # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps
+                    # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half
+                    # so we will merge in now and sample with -1 weight later
+                    self.invert_assistant_lora = True
+                    # trigger it to get merged in
+                    self.model_config.lora_path = self.model_config.assistant_lora_path

             if self.model_config.lora_path is not None:
                 print("Fusing in LoRA")
@@ -763,6 +777,13 @@ class StableDiffusion:
                 # invert and disable during training
                 self.assistant_lora.multiplier = -1.0
                 self.assistant_lora.is_active = False
+
+        if self.model_config.inference_lora_path is not None:
+            print("Loading inference lora")
+            self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
+                self.model_config.inference_lora_path, self)
+            # disable during training
+            self.assistant_lora.is_active = False

         if self.is_pixart and self.vae_scale_factor == 16:
             # TODO make our own pipeline?
@@ -840,6 +861,12 @@ class StableDiffusion:
                 self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
             else:
                 self.assistant_lora.is_active = False
+
+        if self.model_config.inference_lora_path is not None:
+            print("Loading inference lora")
+            self.assistant_lora.is_active = True
+            # move weights on to the device
+            self.assistant_lora.force_to(self.device_torch, self.torch_dtype)

         if self.network is not None:
             self.network.eval()
@@ -1356,5 +1383,11 @@ class StableDiffusion:
                 self.assistant_lora.force_to('cpu', self.torch_dtype)
             else:
                 self.assistant_lora.is_active = True
+
+        if self.model_config.inference_lora_path is not None:
+            print("Unloading inference lora")
+            self.assistant_lora.is_active = False
+            # move weights off the device
+            self.assistant_lora.force_to('cpu', self.torch_dtype)

         flush()