Restore lines for DoRA TE keys fix (#2240)

This commit is contained in:
Lucas Freire Sangoi
2024-11-06 17:20:57 -03:00
committed by GitHub
parent 329c3ca334
commit 75120d02f3

View File

@@ -234,32 +234,32 @@ def model_lora_keys_clip(model, key_map={}):
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
key_map[lora_key] = k
# for k in sdk:
# if k.endswith(".weight"):
# if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
# l_key = k[len("t5xxl.transformer."):-len(".weight")]
# lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
# key_map[lora_key] = k
#
# #####
# lora_key = "lora_te2_{}".format(l_key.replace(".", "_"))#OneTrainer Flux lora, by Forge
# key_map[lora_key] = k
# #####
for k in sdk:
if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
#####
lora_key = "lora_te2_{}".format(l_key.replace(".", "_"))#OneTrainer Flux lora, by Forge
key_map[lora_key] = k
#####
# elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
# l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
# lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
# key_map[lora_key] = k
#
#
# k = "clip_g.transformer.text_projection.weight"
# if k in sdk:
# key_map["lora_prior_te_text_projection"] = k #cascade lora?
# # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
# key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
#
# k = "clip_l.transformer.text_projection.weight"
# if k in sdk:
# key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
# key_map["lora_prior_te_text_projection"] = k #cascade lora?
key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
k = "clip_l.transformer.text_projection.weight"
if k in sdk:
key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
return sdk, key_map