Convert Wan LoRA weights on save to a format ComfyUI can handle

Jaret Burkett
2025-03-08 12:55:11 -07:00
parent 7e37918fbc
commit e6739f7eb2
8 changed files with 108 additions and 5 deletions
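
For context: the trainer writes diffusers-style LoRA keys, while ComfyUI expects the original Wan naming on disk; the converter added below maps between the two. The same weight under both schemes (key names taken from that converter):

# diffusers-style key produced during training
diffusers_key = "transformer.blocks.0.attn1.to_q.lora_A.weight"
# original/Comfy-style key written to the saved file after this commit
comfy_key = "diffusion_model.blocks.0.self_attn.q.lora_A.weight"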

View File

@@ -20,9 +20,13 @@ sys.path.append(SD_SCRIPTS_ROOT)
from networks.lora import LoRANetwork, get_block_index
from toolkit.models.DoRA import DoRAModule
from typing import TYPE_CHECKING
from torch.utils.checkpoint import checkpoint
if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -179,6 +183,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
peft_format: bool = False,
is_assistant_adapter: bool = False,
is_transformer: bool = False,
base_model: 'StableDiffusion' = None,
**kwargs
) -> None:
"""
@@ -204,6 +209,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
ignore_if_contains = []
self.ignore_if_contains = ignore_if_contains
self.transformer_only = transformer_only
self.base_model_ref = weakref.ref(base_model)
self.only_if_contains: Union[List, None] = only_if_contains
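
The new base_model argument is held through a weak reference, so the network and the model that owns it do not keep each other alive. A minimal sketch of the pattern (assuming import weakref is among the lines added at the top of this file):

import weakref

class Network:
    def __init__(self, base_model):
        # a weakref breaks the strong reference cycle between the network
        # and the base model that also holds the network
        self.base_model_ref = weakref.ref(base_model)

    def on_save(self, state_dict):
        model = self.base_model_ref()  # the model, or None if already collected
        return model.convert_lora_weights_before_save(state_dict) if model else state_dict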

View File

@@ -1433,3 +1433,11 @@ class BaseModel:
                encoder.to(*args, **kwargs)
        else:
            self.text_encoder.to(*args, **kwargs)

    def convert_lora_weights_before_save(self, state_dict):
        # can be overridden in child classes to convert weights before saving
        return state_dict

    def convert_lora_weights_before_load(self, state_dict):
        # can be overridden in child classes to convert weights before loading
        return state_dict
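
These hooks are identity functions by default; a subclass that needs a different on-disk layout overrides both so save and load stay inverses. An illustrative sketch (the "model." prefix is made up; the real Wan21 override appears later in this commit):

class MyVideoModel(BaseModel):
    def convert_lora_weights_before_save(self, state_dict):
        # hypothetical on-disk format that prefixes keys with "model."
        return {k.replace("transformer.", "model.", 1): v for k, v in state_dict.items()}

    def convert_lora_weights_before_load(self, state_dict):
        # undo the save-time rename so keys match the network again
        return {k.replace("model.", "transformer.", 1): v for k, v in state_dict.items()}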

View File

@@ -0,0 +1 @@
from .wan21 import Wan21

View File

@@ -42,6 +42,7 @@ from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
# from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from typing import Any, Callable, Dict, List, Optional, Union
from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original
# for generation only?
scheduler_configUniPC = {
@@ -160,14 +161,14 @@ class AggressiveWanUnloadPipeline(WanPipeline):
# unload text encoder
print("Unloading text encoder")
self.text_encoder.to("cpu")
-self.transformer.to(self._execution_device)
+self.transformer.to(device)
transformer_dtype = self.transformer.dtype
-prompt_embeds = prompt_embeds.to(transformer_dtype)
+prompt_embeds = prompt_embeds.to(device, transformer_dtype)
if negative_prompt_embeds is not None:
    negative_prompt_embeds = negative_prompt_embeds.to(
-        transformer_dtype)
+        device, transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -198,7 +199,7 @@ class AggressiveWanUnloadPipeline(WanPipeline):
continue
self._current_timestep = t
-latent_model_input = latents.to(transformer_dtype)
+latent_model_input = latents.to(device, transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
@@ -468,6 +469,8 @@ class Wan21(BaseModel):
scheduler=scheduler,
)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
@@ -481,6 +484,7 @@ class Wan21(BaseModel):
):
# reactivate progress bar since this is slooooow
pipeline.set_progress_bar_config(disable=False)
pipeline = pipeline.to(self.device_torch)
# todo, figure out how to do video
output = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
@@ -619,3 +623,9 @@ class Wan21(BaseModel):
if noise is None:
    raise ValueError("Noise is not provided")
return (noise - batch.latents).detach()
    def convert_lora_weights_before_save(self, state_dict):
        return convert_to_original(state_dict)

    def convert_lora_weights_before_load(self, state_dict):
        return convert_to_diffusers(state_dict)
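
The pipeline edits above swap self._execution_device for the explicit device and fold the move and dtype cast into one call, which matters once the text encoder has been offloaded to the CPU. A standalone sketch of that pattern (all names here are stand-ins, not the pipeline's code):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer_dtype = torch.bfloat16

# stand-in for prompt embeddings that may still live on the CPU after
# the text encoder was moved off the GPU
prompt_embeds = torch.randn(1, 512, 4096)

# Tensor.to accepts device and dtype together: one call moves the tensor
# to the execution device and casts it to the transformer's dtype
prompt_embeds = prompt_embeds.to(device, transformer_dtype)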

View File

@@ -0,0 +1,65 @@
def convert_to_diffusers(state_dict):
    new_state_dict = {}
    for key in state_dict:
        new_key = key

        # Base model name change
        if key.startswith("diffusion_model."):
            new_key = key.replace("diffusion_model.", "transformer.")

        # Attention blocks conversion
        if "self_attn" in new_key:
            new_key = new_key.replace("self_attn", "attn1")
        elif "cross_attn" in new_key:
            new_key = new_key.replace("cross_attn", "attn2")

        # Attention components conversion
        parts = new_key.split(".")
        for i, part in enumerate(parts):
            if part in ["q", "k", "v"]:
                parts[i] = f"to_{part}"
            elif part == "o":
                parts[i] = "to_out.0"
        new_key = ".".join(parts)

        # FFN conversion
        if "ffn.0" in new_key:
            new_key = new_key.replace("ffn.0", "ffn.net.0.proj")
        elif "ffn.2" in new_key:
            new_key = new_key.replace("ffn.2", "ffn.net.2")

        new_state_dict[new_key] = state_dict[key]

    return new_state_dict


def convert_to_original(state_dict):
    new_state_dict = {}
    for key in state_dict:
        new_key = key

        # Base model name change
        if key.startswith("transformer."):
            new_key = key.replace("transformer.", "diffusion_model.")

        # Attention blocks conversion
        if "attn1" in new_key:
            new_key = new_key.replace("attn1", "self_attn")
        elif "attn2" in new_key:
            new_key = new_key.replace("attn2", "cross_attn")

        # Attention components conversion
        if "to_out.0" in new_key:
            new_key = new_key.replace("to_out.0", "o")
        elif "to_q" in new_key:
            new_key = new_key.replace("to_q", "q")
        elif "to_k" in new_key:
            new_key = new_key.replace("to_k", "k")
        elif "to_v" in new_key:
            new_key = new_key.replace("to_v", "v")

        # FFN conversion
        if "ffn.net.0.proj" in new_key:
            new_key = new_key.replace("ffn.net.0.proj", "ffn.0")
        elif "ffn.net.2" in new_key:
            new_key = new_key.replace("ffn.net.2", "ffn.2")

        new_state_dict[new_key] = state_dict[key]

    return new_state_dict
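
A quick round-trip sanity check for these two mappings (a sketch with dummy tensors; assumes it runs alongside the functions above):

import torch

sd = {
    "transformer.blocks.0.attn1.to_q.lora_A.weight": torch.zeros(8, 8),
    "transformer.blocks.0.ffn.net.0.proj.lora_B.weight": torch.zeros(8, 8),
}

saved = convert_to_original(sd)
assert "diffusion_model.blocks.0.self_attn.q.lora_A.weight" in saved
assert "diffusion_model.blocks.0.ffn.0.lora_B.weight" in saved

# loading should undo the save-time conversion exactly
assert set(convert_to_diffusers(saved)) == set(sd)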

View File

@@ -554,6 +554,8 @@ class ToolkitNetworkMixin:
new_save_dict[new_key] = value
save_dict = new_save_dict
save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict)
if metadata is None:
    metadata = OrderedDict()
@@ -579,6 +581,8 @@ class ToolkitNetworkMixin:
else:
    # probably a state dict
    weights_sd = file
weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd)
load_sd = OrderedDict()
for key, value in weights_sd.items():

View File

@@ -3061,3 +3061,11 @@ class StableDiffusion:
                encoder.to(*args, **kwargs)
        else:
            self.text_encoder.to(*args, **kwargs)

    def convert_lora_weights_before_save(self, state_dict):
        # can be overridden in child classes to convert weights before saving
        return state_dict

    def convert_lora_weights_before_load(self, state_dict):
        # can be overridden in child classes to convert weights before loading
        return state_dict
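
End to end, the load hook lets a LoRA saved in the Comfy-friendly layout be read straight back. A usage sketch (the file name and the model variable are illustrative):

from safetensors.torch import load_file

weights_sd = load_file("wan_lora.safetensors")  # Comfy-style "diffusion_model." keys
weights_sd = model.convert_lora_weights_before_load(weights_sd)  # model: e.g. a Wan21 instance
# keys are now diffusers-style ("transformer.", "attn1", "to_q", ...) and
# line up with the network's module names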