diff --git a/toolkit/kohya_model_util.py b/toolkit/kohya_model_util.py
index 5e976f0b..798fc2dc 100644
--- a/toolkit/kohya_model_util.py
+++ b/toolkit/kohya_model_util.py
@@ -892,6 +892,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
     for key in keys:
         if key.startswith("cond_stage_model.transformer"):
             text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
+    # support checkpoint without position_ids (invalid checkpoint)
+    if "text_model.embeddings.position_ids" not in text_model_dict:
+        text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)  # 77 is the max length of the text
     return text_model_dict


@@ -1257,6 +1260,10 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
         text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
         logging.set_verbosity_warning()

+    # latest transformers doesn't keep position_ids in the CLIP state dict, so drop the key when the loaded model doesn't expect it
+    if "text_model.embeddings.position_ids" not in text_model.state_dict():
+        del converted_text_encoder_checkpoint["text_model.embeddings.position_ids"]
+
     info = text_model.load_state_dict(converted_text_encoder_checkpoint)
     print("loading text encoder:", info)
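
For context: recent transformers releases register text_model.embeddings.position_ids as a non-persistent buffer, so the key no longer appears in CLIPTextModel.state_dict(), while older releases (and most .ckpt files) still carry it. The sketch below is a minimal, self-contained illustration of the same add-or-drop handling the two hunks perform; "converted" is a hypothetical stand-in for the converted text encoder state dict and is not a name from this repo.

# Minimal sketch of the position_ids handling above (assumes torch and
# transformers are installed; "converted" is a hypothetical placeholder).
import torch
from transformers import CLIPTextModel

text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

converted = {}  # stand-in for the converted text encoder state dict

# Hunk 1: old or invalid checkpoints may lack position_ids, so synthesize it
# (77 is the CLIP text encoder's max token length).
if "text_model.embeddings.position_ids" not in converted:
    converted["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)

# Hunk 2: newer transformers no longer keeps position_ids in state_dict(),
# so drop the key before load_state_dict() to avoid an "unexpected key" error.
if "text_model.embeddings.position_ids" not in text_model.state_dict():
    converted.pop("text_model.embeddings.position_ids", None)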