diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 475c4845..e6c1c28b 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -379,6 +379,8 @@ class ModelConfig:
         self._original_refiner_name_or_path = self.refiner_name_or_path
         self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
         self.lora_path = kwargs.get('lora_path', None)
+        # mainly for decompression loras for distilled models
+        self.assistant_lora_path = kwargs.get('assistant_lora_path', None)
         self.latent_space_version = kwargs.get('latent_space_version', None)

         # only for SDXL models for now
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index a3d369bc..92473dee 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -124,6 +124,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
+    PEFT_PREFIX_UNET = "unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"

     # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
@@ -171,6 +172,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         network_type: str = "lora",
         full_train_in_out: bool = False,
         transformer_only: bool = False,
+        peft_format: bool = False,
         **kwargs
     ) -> None:
         """
@@ -223,6 +225,17 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             self.module_class = DoRAModule
             module_class = DoRAModule

+        self.peft_format = peft_format
+
+        # always use the peft format for flux, for now
+        if self.is_flux:
+            self.peft_format = True
+
+        if self.peft_format:
+            # no alpha for peft
+            self.alpha = self.lora_dim
+            self.conv_alpha = self.conv_lora_dim
+
         self.full_train_in_out = full_train_in_out

         if modules_dim is not None:
@@ -252,8 +265,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         target_replace_modules: List[torch.nn.Module],
     ) -> List[LoRAModule]:

         unet_prefix = self.LORA_PREFIX_UNET
+        if self.peft_format:
+            unet_prefix = self.PEFT_PREFIX_UNET
         if is_pixart or is_v3 or is_auraflow or is_flux:
             unet_prefix = f"lora_transformer"
+            if self.peft_format:
+                unet_prefix = "transformer"
         prefix = (
             unet_prefix
@@ -282,7 +299,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):

                 lora_name = ".".join(lora_name)
                 # if it doesnt have a name, it wil have two dots
                 lora_name.replace("..", ".")
-                lora_name = lora_name.replace(".", "_")
+                if self.peft_format:
+                    # we replace this on saving
+                    lora_name = lora_name.replace(".", "$$")
+                else:
+                    lora_name = lora_name.replace(".", "_")
+
                 skip = False
                 if any([word in child_name for word in self.ignore_if_contains]):
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 2df15b82..94781e0c 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -204,7 +204,6 @@ class ToolkitModuleMixin:

         return lx * scale

-
     def lorm_forward(self: Network, x, *args, **kwargs):
         network: Network = self.network_ref()
         if not network.is_active:
@@ -492,6 +491,24 @@ class ToolkitNetworkMixin:
                 v = v.detach().clone().to("cpu").to(dtype)
                 save_dict[key] = v

+        if self.peft_format:
+            # lora_down = lora_A
+            # lora_up = lora_B
+            # no alpha
+
+            new_save_dict = {}
+            for key, value in save_dict.items():
+                if key.endswith('.alpha'):
+                    continue
+                new_key = key
+                new_key = new_key.replace('lora_down', 'lora_A')
+                new_key = new_key.replace('lora_up', 'lora_B')
+                # replace all $$ with .
+                new_key = new_key.replace('$$', '.')
+                new_save_dict[new_key] = value
+
+            save_dict = new_save_dict
+
         if metadata is None:
             metadata = OrderedDict()
             metadata = add_model_hash_to_meta(state_dict, metadata)
@@ -519,6 +536,20 @@ class ToolkitNetworkMixin:
             # replace old double __ with single _
             if self.is_pixart:
                 load_key = load_key.replace('__', '_')
+
+            if self.peft_format:
+                # lora_down = lora_A
+                # lora_up = lora_B
+                # no alpha
+                if load_key.endswith('.alpha'):
+                    continue
+                load_key = load_key.replace('lora_A', 'lora_down')
+                load_key = load_key.replace('lora_B', 'lora_up')
+                # replace all . with $$
+                load_key = load_key.replace('.', '$$')
+                load_key = load_key.replace('$$lora_down$$', '.lora_down.')
+                load_key = load_key.replace('$$lora_up$$', '.lora_up.')
+
             load_sd[load_key] = value

         # extract extra items from state dict
@@ -533,7 +564,8 @@ class ToolkitNetworkMixin:
                 del load_sd[key]
             print(f"Missing keys: {to_delete}")

-        if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (len(to_delete) == 1 and 'emb_params' in to_delete):
+        if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (
+                len(to_delete) == 1 and 'emb_params' in to_delete):
             print(" Attempting to load with forced keymap")
             return self.load_weights(file, force_weight_mapping=True)

@@ -657,4 +689,3 @@ class ToolkitNetworkMixin:
             params_reduced += (num_orig_module_params - num_lorem_params)

         return params_reduced
-
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index f3426714..b698eae6 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -616,7 +616,17 @@ class StableDiffusion:
             if self.model_config.lora_path is not None:
                 pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
                 pipe.fuse_lora()
-                self.unet.fuse_lora()
+                # unfortunately, no easier way with peft
+                pipe.unload_lora_weights()
+
+            if self.model_config.assistant_lora_path is not None:
+                if self.model_config.lora_path is not None:
+                    raise ValueError("Cannot have both lora and assistant lora")
+                print("Loading assistant lora")
+                pipe.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
+                pipe.fuse_lora(lora_scale=1.0)
+                # unfortunately, no easier way with peft
+                pipe.unload_lora_weights()

             self.tokenizer = tokenizer
             self.text_encoder = text_encoder
@@ -690,7 +700,15 @@ class StableDiffusion:
         pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
    ):
         merge_multiplier = 1.0
-        # sample_folder = os.path.join(self.save_root, 'samples')
+
+        # if using an assistant lora, unfuse it before sampling
+        if self.model_config.assistant_lora_path is not None:
+            print("Unloading assistant lora")
+            # unfortunately, no easier way with peft
+            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
+            self.pipeline.fuse_lora(lora_scale=-1.0)
+            self.pipeline.unload_lora_weights()
+
         if self.network is not None:
             self.network.eval()
             network = self.network
@@ -1162,6 +1180,14 @@ class StableDiffusion:
             network.merge_out(merge_multiplier)
         # self.tokenizer.to(original_device_dict['tokenizer'])

+        # re-fuse the assistant lora after sampling
+        if self.model_config.assistant_lora_path is not None:
+            print("Loading assistant lora")
+            # unfortunately, no easier way with peft
+            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
+            self.pipeline.fuse_lora(lora_scale=1.0)
+            self.pipeline.unload_lora_weights()
+
     def get_latent_noise(
         self,
         height=None,
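
Taken together, the `lora_special.py` and `network_mixins.py` changes define a reversible key mapping: internally, module names keep a `$$` placeholder instead of dots, and `save_weights` converts keys to peft form (`lora_A`/`lora_B`, real dots, no alpha) while `load_weights` reverses the conversion. A minimal standalone sketch of that round-trip, with a hypothetical key, assuming only what the diff shows:

```python
# Standalone sketch of the peft key round-trip above (the key below is
# hypothetical; the real code walks the network's state dict).

def to_peft_key(key: str) -> str:
    # on save: rename lora_down/lora_up to peft's lora_A/lora_B,
    # drop nothing else here, then turn '$$' placeholders back into dots
    key = key.replace('lora_down', 'lora_A')
    key = key.replace('lora_up', 'lora_B')
    return key.replace('$$', '.')

def from_peft_key(key: str) -> str:
    # on load: reverse the rename, re-protect module-path dots with '$$',
    # then restore only the '.lora_down.' / '.lora_up.' separators
    key = key.replace('lora_A', 'lora_down')
    key = key.replace('lora_B', 'lora_up')
    key = key.replace('.', '$$')
    key = key.replace('$$lora_down$$', '.lora_down.')
    return key.replace('$$lora_up$$', '.lora_up.')

internal = 'transformer$$blocks$$0$$attn.lora_down.weight'  # hypothetical
peft = to_peft_key(internal)  # 'transformer.blocks.0.attn.lora_A.weight'
assert from_peft_key(peft) == internal
```

The load direction has to re-protect the dots inside module paths, which is why it first converts every `.` to `$$` and only then restores the `.lora_down.`/`.lora_up.` separators; keys ending in `.alpha` are skipped in both directions, since peft checkpoints carry no alpha.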
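The assistant lora in `stable_diffusion_model.py` is never kept resident as a peft adapter: it is loaded, fused into the base weights, and immediately unloaded. Sampling then cancels it by fusing the same weights at a negative scale and restores it afterwards. A sketch of that cycle against the diffusers pipeline API; the helper name, pipeline, and paths below are stand-ins, not toolkit API:

```python
from diffusers import AutoPipelineForText2Image  # any lora-capable pipeline

def fuse_assistant(pipe, path, scale=1.0):
    # load_lora_weights attaches the adapter, fuse_lora merges it into the
    # base weights at `scale`, and unload_lora_weights drops the adapter
    # modules, leaving only the modified base weights behind
    pipe.load_lora_weights(path, adapter_name="assistant_lora")
    pipe.fuse_lora(lora_scale=scale)
    pipe.unload_lora_weights()

# usage (model and lora paths are stand-ins):
# pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-schnell")
# fuse_assistant(pipe, "assistant_lora.safetensors", scale=1.0)   # at load: train with it fused
# fuse_assistant(pipe, "assistant_lora.safetensors", scale=-1.0)  # before sampling: cancel it
# fuse_assistant(pipe, "assistant_lora.safetensors", scale=1.0)   # after sampling: fuse it back
```

Because fusing is linear in the scale, fusing at `-1.0` subtracts exactly what fusing at `1.0` added, so the base weights return to their assistant-free state for sampling without an adapter ever staying attached.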