mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Fix naming with wan i2v new keys in lora
This commit is contained in:
@@ -17,6 +17,10 @@ def convert_to_diffusers(state_dict):
|
|||||||
for i, part in enumerate(parts):
|
for i, part in enumerate(parts):
|
||||||
if part in ["q", "k", "v"]:
|
if part in ["q", "k", "v"]:
|
||||||
parts[i] = f"to_{part}"
|
parts[i] = f"to_{part}"
|
||||||
|
elif part == "k_img":
|
||||||
|
parts[i] = "add_k_proj"
|
||||||
|
elif part == "v_img":
|
||||||
|
parts[i] = "add_v_proj"
|
||||||
elif part == "o":
|
elif part == "o":
|
||||||
parts[i] = "to_out.0"
|
parts[i] = "to_out.0"
|
||||||
new_key = ".".join(parts)
|
new_key = ".".join(parts)
|
||||||
@@ -54,6 +58,12 @@ def convert_to_original(state_dict):
|
|||||||
new_key = new_key.replace("to_k", "k")
|
new_key = new_key.replace("to_k", "k")
|
||||||
elif "to_v" in new_key:
|
elif "to_v" in new_key:
|
||||||
new_key = new_key.replace("to_v", "v")
|
new_key = new_key.replace("to_v", "v")
|
||||||
|
|
||||||
|
# img attn projection
|
||||||
|
elif "add_k_proj" in new_key:
|
||||||
|
new_key = new_key.replace("add_k_proj", "k_img")
|
||||||
|
elif "add_v_proj" in new_key:
|
||||||
|
new_key = new_key.replace("add_v_proj", "v_img")
|
||||||
|
|
||||||
# FFN conversion
|
# FFN conversion
|
||||||
if "ffn.net.0.proj" in new_key:
|
if "ffn.net.0.proj" in new_key:
|
||||||
|
|||||||
Reference in New Issue
Block a user