Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Convert wan lora weights on save to be something comfy can handle
@@ -1611,6 +1611,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 network_type=self.network_config.type,
                 transformer_only=self.network_config.transformer_only,
                 is_transformer=self.sd.is_transformer,
+                base_model=self.sd,
                 **network_kwargs
             )
@@ -20,9 +20,13 @@ sys.path.append(SD_SCRIPTS_ROOT)

 from networks.lora import LoRANetwork, get_block_index
 from toolkit.models.DoRA import DoRAModule
+from typing import TYPE_CHECKING

 from torch.utils.checkpoint import checkpoint

+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")

@@ -179,6 +183,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             peft_format: bool = False,
             is_assistant_adapter: bool = False,
             is_transformer: bool = False,
+            base_model: 'StableDiffusion' = None,
             **kwargs
     ) -> None:
         """
@@ -204,6 +209,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             ignore_if_contains = []
         self.ignore_if_contains = ignore_if_contains
         self.transformer_only = transformer_only
+        self.base_model_ref = weakref.ref(base_model)

         self.only_if_contains: Union[List, None] = only_if_contains
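The network stores the base model only through weakref.ref, so calling self.base_model_ref() later (see the save/load hooks further down) returns the live model without the network holding a strong reference to it. A minimal sketch of that pattern, using a stand-in class instead of the real StableDiffusion model:

import weakref

class DummyBaseModel:
    # stand-in for the real base model; only the hook matters here
    def convert_lora_weights_before_save(self, state_dict):
        return state_dict

model = DummyBaseModel()
base_model_ref = weakref.ref(model)   # keep a weak reference, not the model itself
state_dict = base_model_ref().convert_lora_weights_before_save({})  # calling () dereferences the weakref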
@@ -1433,3 +1433,11 @@ class BaseModel:
                 encoder.to(*args, **kwargs)
         else:
             self.text_encoder.to(*args, **kwargs)
+
+    def convert_lora_weights_before_save(self, state_dict):
+        # can be overridden in child classes to convert weights before saving
+        return state_dict
+
+    def convert_lora_weights_before_load(self, state_dict):
+        # can be overridden in child classes to convert weights before loading
+        return state_dict
toolkit/models/wan21/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .wan21 import Wan21
@@ -42,6 +42,7 @@ from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
 # from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from typing import Any, Callable, Dict, List, Optional, Union
+from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original

 # for generation only?
 scheduler_configUniPC = {
@@ -160,14 +161,14 @@ class AggressiveWanUnloadPipeline(WanPipeline):
         # unload text encoder
         print("Unloading text encoder")
         self.text_encoder.to("cpu")

-        self.transformer.to(self._execution_device)
+        self.transformer.to(device)

         transformer_dtype = self.transformer.dtype
-        prompt_embeds = prompt_embeds.to(transformer_dtype)
+        prompt_embeds = prompt_embeds.to(device, transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(
-                transformer_dtype)
+                device, transformer_dtype)

         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -198,7 +199,7 @@ class AggressiveWanUnloadPipeline(WanPipeline):
                     continue

                 self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
+                latent_model_input = latents.to(device, transformer_dtype)
                 timestep = t.expand(latents.shape[0])

                 noise_pred = self.transformer(
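The recurring change in this pipeline is passing the device together with the dtype to Tensor.to, which moves and casts in one call, presumably because after the aggressive offloading the latents and embeddings are not guaranteed to already sit on the execution device. A tiny illustration (the target device here is only illustrative):

import torch

x = torch.randn(2, 2)            # starts on CPU as float32
y = x.to("cpu", torch.float16)   # device and dtype handled in a single .to() call
assert y.device.type == "cpu" and y.dtype == torch.float16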
@@ -468,6 +469,8 @@ class Wan21(BaseModel):
             scheduler=scheduler,
         )

+        pipeline = pipeline.to(self.device_torch)
+
         return pipeline

     def generate_single_image(
@@ -481,6 +484,7 @@ class Wan21(BaseModel):
     ):
         # reactivate progress bar since this is slooooow
         pipeline.set_progress_bar_config(disable=False)
+        pipeline = pipeline.to(self.device_torch)
         # todo, figure out how to do video
         output = pipeline(
             prompt_embeds=conditional_embeds.text_embeds.to(
@@ -619,3 +623,9 @@ class Wan21(BaseModel):
         if noise is None:
             raise ValueError("Noise is not provided")
         return (noise - batch.latents).detach()
+
+    def convert_lora_weights_before_save(self, state_dict):
+        return convert_to_original(state_dict)
+
+    def convert_lora_weights_before_load(self, state_dict):
+        return convert_to_diffusers(state_dict)
toolkit/models/wan21/wan_lora_convert.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+def convert_to_diffusers(state_dict):
+    new_state_dict = {}
+    for key in state_dict:
+        new_key = key
+        # Base model name change
+        if key.startswith("diffusion_model."):
+            new_key = key.replace("diffusion_model.", "transformer.")
+
+        # Attention blocks conversion
+        if "self_attn" in new_key:
+            new_key = new_key.replace("self_attn", "attn1")
+        elif "cross_attn" in new_key:
+            new_key = new_key.replace("cross_attn", "attn2")
+
+        # Attention components conversion
+        parts = new_key.split(".")
+        for i, part in enumerate(parts):
+            if part in ["q", "k", "v"]:
+                parts[i] = f"to_{part}"
+            elif part == "o":
+                parts[i] = "to_out.0"
+        new_key = ".".join(parts)
+
+        # FFN conversion
+        if "ffn.0" in new_key:
+            new_key = new_key.replace("ffn.0", "ffn.net.0.proj")
+        elif "ffn.2" in new_key:
+            new_key = new_key.replace("ffn.2", "ffn.net.2")
+
+        new_state_dict[new_key] = state_dict[key]
+    return new_state_dict
+
+
+def convert_to_original(state_dict):
+    new_state_dict = {}
+    for key in state_dict:
+        new_key = key
+        # Base model name change
+        if key.startswith("transformer."):
+            new_key = key.replace("transformer.", "diffusion_model.")
+
+        # Attention blocks conversion
+        if "attn1" in new_key:
+            new_key = new_key.replace("attn1", "self_attn")
+        elif "attn2" in new_key:
+            new_key = new_key.replace("attn2", "cross_attn")
+
+        # Attention components conversion
+        if "to_out.0" in new_key:
+            new_key = new_key.replace("to_out.0", "o")
+        elif "to_q" in new_key:
+            new_key = new_key.replace("to_q", "q")
+        elif "to_k" in new_key:
+            new_key = new_key.replace("to_k", "k")
+        elif "to_v" in new_key:
+            new_key = new_key.replace("to_v", "v")
+
+        # FFN conversion
+        if "ffn.net.0.proj" in new_key:
+            new_key = new_key.replace("ffn.net.0.proj", "ffn.0")
+        elif "ffn.net.2" in new_key:
+            new_key = new_key.replace("ffn.net.2", "ffn.2")
+
+        new_state_dict[new_key] = state_dict[key]
+    return new_state_dict
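A minimal usage sketch of the two converters above. The key below is illustrative; the functions only remap the module path (prefix, attention and FFN names) and leave the LoRA weight suffix untouched:

import torch
from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original

original = {"diffusion_model.blocks.0.self_attn.q.lora_down.weight": torch.zeros(4, 4)}

diffusers_style = convert_to_diffusers(original)
print(list(diffusers_style))
# ['transformer.blocks.0.attn1.to_q.lora_down.weight']

round_trip = convert_to_original(diffusers_style)
assert list(round_trip) == list(original)  # key names survive the round trip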
@@ -554,6 +554,8 @@ class ToolkitNetworkMixin:
                 new_save_dict[new_key] = value

             save_dict = new_save_dict
+
+        save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict)

         if metadata is None:
             metadata = OrderedDict()
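A hedged sketch of the save path this hook plugs into (helper names other than base_model_ref and convert_lora_weights_before_save are assumptions, not the toolkit's real API): the mixin finishes building its save_dict, lets the base model remap the keys (Wan21 maps them back to the diffusion_model.* naming that ComfyUI expects), then writes the file.

from collections import OrderedDict
from safetensors.torch import save_file

def save_lora(network, file_path, metadata=None):
    save_dict = network.get_state_dict()  # hypothetical helper returning the LoRA tensors
    save_dict = network.base_model_ref().convert_lora_weights_before_save(save_dict)
    if metadata is None:
        metadata = OrderedDict()
    save_file(save_dict, file_path, metadata)  # safetensors expects str -> str metadata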
@@ -579,6 +581,8 @@ class ToolkitNetworkMixin:
         else:
             # probably a state dict
             weights_sd = file
+
+        weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd)

         load_sd = OrderedDict()
         for key, value in weights_sd.items():
@@ -3061,3 +3061,11 @@ class StableDiffusion:
                 encoder.to(*args, **kwargs)
         else:
             self.text_encoder.to(*args, **kwargs)
+
+    def convert_lora_weights_before_save(self, state_dict):
+        # can be overridden in child classes to convert weights before saving
+        return state_dict
+
+    def convert_lora_weights_before_load(self, state_dict):
+        # can be overridden in child classes to convert weights before loading
+        return state_dict