diff --git a/toolkit/models/wan21/wan_lora_convert.py b/toolkit/models/wan21/wan_lora_convert.py index 69fb1703..44001f17 100644 --- a/toolkit/models/wan21/wan_lora_convert.py +++ b/toolkit/models/wan21/wan_lora_convert.py @@ -17,6 +17,10 @@ def convert_to_diffusers(state_dict): for i, part in enumerate(parts): if part in ["q", "k", "v"]: parts[i] = f"to_{part}" + elif part == "k_img": + parts[i] = "add_k_proj" + elif part == "v_img": + parts[i] = "add_v_proj" elif part == "o": parts[i] = "to_out.0" new_key = ".".join(parts) @@ -54,6 +58,12 @@ def convert_to_original(state_dict): new_key = new_key.replace("to_k", "k") elif "to_v" in new_key: new_key = new_key.replace("to_v", "v") + + # img attn projection + elif "add_k_proj" in new_key: + new_key = new_key.replace("add_k_proj", "k_img") + elif "add_v_proj" in new_key: + new_key = new_key.replace("add_v_proj", "v_img") # FFN conversion if "ffn.net.0.proj" in new_key: