mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Hotfix: handle a key suddenly missing from the latest transformers CLIP model
This commit is contained in:
@@ -892,6 +892,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
|||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("cond_stage_model.transformer"):
|
if key.startswith("cond_stage_model.transformer"):
|
||||||
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
||||||
|
# support checkpoint without position_ids (invalid checkpoint)
|
||||||
|
if "text_model.embeddings.position_ids" not in text_model_dict:
|
||||||
|
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
||||||
return text_model_dict
|
return text_model_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -1257,6 +1260,10 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
|||||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
||||||
logging.set_verbosity_warning()
|
logging.set_verbosity_warning()
|
||||||
|
|
||||||
|
# latest transformers doesnt have position ids. Do we remove it?
|
||||||
|
if "text_model.embeddings.position_ids" not in text_model.state_dict():
|
||||||
|
del converted_text_encoder_checkpoint["text_model.embeddings.position_ids"]
|
||||||
|
|
||||||
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
||||||
print("loading text encoder:", info)
|
print("loading text encoder:", info)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user