From e6739f7eb2fe8fd25950b756f1aecd89fa3abdc9 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 8 Mar 2025 12:55:11 -0700 Subject: [PATCH] Convert wan lora weights on save to be something comfy can handle --- jobs/process/BaseSDTrainProcess.py | 1 + toolkit/lora_special.py | 6 +++ toolkit/models/base_model.py | 8 +++ toolkit/models/wan21/__init__.py | 1 + toolkit/models/{ => wan21}/wan21.py | 20 ++++++-- toolkit/models/wan21/wan_lora_convert.py | 65 ++++++++++++++++++++++++ toolkit/network_mixins.py | 4 ++ toolkit/stable_diffusion_model.py | 8 +++ 8 files changed, 108 insertions(+), 5 deletions(-) create mode 100644 toolkit/models/wan21/__init__.py rename toolkit/models/{ => wan21}/wan21.py (97%) create mode 100644 toolkit/models/wan21/wan_lora_convert.py diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 03ccc9cf..149c6668 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1611,6 +1611,7 @@ class BaseSDTrainProcess(BaseTrainProcess): network_type=self.network_config.type, transformer_only=self.network_config.transformer_only, is_transformer=self.sd.is_transformer, + base_model=self.sd, **network_kwargs ) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index c38feec8..4dd94b0f 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -20,9 +20,13 @@ sys.path.append(SD_SCRIPTS_ROOT) from networks.lora import LoRANetwork, get_block_index from toolkit.models.DoRA import DoRAModule +from typing import TYPE_CHECKING from torch.utils.checkpoint import checkpoint +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -179,6 +183,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): peft_format: bool = False, is_assistant_adapter: bool = False, is_transformer: bool = False, + base_model: 'StableDiffusion' = None, **kwargs ) -> None: """ @@ -204,6 +209,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): ignore_if_contains = [] self.ignore_if_contains = ignore_if_contains self.transformer_only = transformer_only + self.base_model_ref = weakref.ref(base_model) self.only_if_contains: Union[List, None] = only_if_contains diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index c9d277a0..56c506da 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -1433,3 +1433,11 @@ class BaseModel: encoder.to(*args, **kwargs) else: self.text_encoder.to(*args, **kwargs) + + def convert_lora_weights_before_save(self, state_dict): + # can be overridden in child classes to convert weights before saving + return state_dict + + def convert_lora_weights_before_load(self, state_dict): + # can be overridden in child classes to convert weights before loading + return state_dict diff --git a/toolkit/models/wan21/__init__.py b/toolkit/models/wan21/__init__.py new file mode 100644 index 00000000..9e2aa3ca --- /dev/null +++ b/toolkit/models/wan21/__init__.py @@ -0,0 +1 @@ +from .wan21 import Wan21 \ No newline at end of file diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21/wan21.py similarity index 97% rename from toolkit/models/wan21.py rename to toolkit/models/wan21/wan21.py index af1ede2b..105191c3 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -42,6 +42,7 @@ from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE # from ...callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from typing import Any, Callable, Dict, List, Optional, Union +from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original # for generation only? scheduler_configUniPC = { @@ -160,14 +161,14 @@ class AggressiveWanUnloadPipeline(WanPipeline): # unload text encoder print("Unloading text encoder") self.text_encoder.to("cpu") - - self.transformer.to(self._execution_device) + + self.transformer.to(device) transformer_dtype = self.transformer.dtype - prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_embeds = prompt_embeds.to(device, transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to( - transformer_dtype) + device, transformer_dtype) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -198,7 +199,7 @@ class AggressiveWanUnloadPipeline(WanPipeline): continue self._current_timestep = t - latent_model_input = latents.to(transformer_dtype) + latent_model_input = latents.to(device, transformer_dtype) timestep = t.expand(latents.shape[0]) noise_pred = self.transformer( @@ -468,6 +469,8 @@ class Wan21(BaseModel): scheduler=scheduler, ) + pipeline = pipeline.to(self.device_torch) + return pipeline def generate_single_image( @@ -481,6 +484,7 @@ class Wan21(BaseModel): ): # reactivate progress bar since this is slooooow pipeline.set_progress_bar_config(disable=False) + pipeline = pipeline.to(self.device_torch) # todo, figure out how to do video output = pipeline( prompt_embeds=conditional_embeds.text_embeds.to( @@ -619,3 +623,9 @@ class Wan21(BaseModel): if noise is None: raise ValueError("Noise is not provided") return (noise - batch.latents).detach() + + def convert_lora_weights_before_save(self, state_dict): + return convert_to_original(state_dict) + + def convert_lora_weights_before_load(self, state_dict): + return convert_to_diffusers(state_dict) diff --git a/toolkit/models/wan21/wan_lora_convert.py b/toolkit/models/wan21/wan_lora_convert.py new file mode 100644 index 00000000..69fb1703 --- /dev/null +++ b/toolkit/models/wan21/wan_lora_convert.py @@ -0,0 +1,65 @@ +def convert_to_diffusers(state_dict): + new_state_dict = {} + for key in state_dict: + new_key = key + # Base model name change + if key.startswith("diffusion_model."): + new_key = key.replace("diffusion_model.", "transformer.") + + # Attention blocks conversion + if "self_attn" in new_key: + new_key = new_key.replace("self_attn", "attn1") + elif "cross_attn" in new_key: + new_key = new_key.replace("cross_attn", "attn2") + + # Attention components conversion + parts = new_key.split(".") + for i, part in enumerate(parts): + if part in ["q", "k", "v"]: + parts[i] = f"to_{part}" + elif part == "o": + parts[i] = "to_out.0" + new_key = ".".join(parts) + + # FFN conversion + if "ffn.0" in new_key: + new_key = new_key.replace("ffn.0", "ffn.net.0.proj") + elif "ffn.2" in new_key: + new_key = new_key.replace("ffn.2", "ffn.net.2") + + new_state_dict[new_key] = state_dict[key] + return new_state_dict + + +def convert_to_original(state_dict): + new_state_dict = {} + for key in state_dict: + new_key = key + # Base model name change + if key.startswith("transformer."): + new_key = key.replace("transformer.", "diffusion_model.") + + # Attention blocks conversion + if "attn1" in new_key: + new_key = new_key.replace("attn1", "self_attn") + elif "attn2" in new_key: + new_key = new_key.replace("attn2", "cross_attn") + + # Attention components conversion + if "to_out.0" in new_key: + new_key = new_key.replace("to_out.0", "o") + elif "to_q" in new_key: + new_key = new_key.replace("to_q", "q") + elif "to_k" in new_key: + new_key = new_key.replace("to_k", "k") + elif "to_v" in new_key: + new_key = new_key.replace("to_v", "v") + + # FFN conversion + if "ffn.net.0.proj" in new_key: + new_key = new_key.replace("ffn.net.0.proj", "ffn.0") + elif "ffn.net.2" in new_key: + new_key = new_key.replace("ffn.net.2", "ffn.2") + + new_state_dict[new_key] = state_dict[key] + return new_state_dict diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index d2d1a500..6346eacd 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -554,6 +554,8 @@ class ToolkitNetworkMixin: new_save_dict[new_key] = value save_dict = new_save_dict + + save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict) if metadata is None: metadata = OrderedDict() @@ -579,6 +581,8 @@ class ToolkitNetworkMixin: else: # probably a state dict weights_sd = file + + weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd) load_sd = OrderedDict() for key, value in weights_sd.items(): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 65736178..892385b6 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -3061,3 +3061,11 @@ class StableDiffusion: encoder.to(*args, **kwargs) else: self.text_encoder.to(*args, **kwargs) + + def convert_lora_weights_before_save(self, state_dict): + # can be overridden in child classes to convert weights before saving + return state_dict + + def convert_lora_weights_before_load(self, state_dict): + # can be overridden in child classes to convert weights before loading + return state_dict