diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index cf3fdf17..54f01954 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -16,6 +16,7 @@ from modules import sd_hijack from modules import sd_hijack_clip, sd_hijack_open_clip from modules.sd_models_xl import extend_sdxl from ldm.util import instantiate_from_config +from modules_forge import forge_clip import open_clip from transformers import CLIPTextModel, CLIPTokenizer @@ -167,7 +168,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): model_embeddings = embedder.transformer.text_model.embeddings model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes( model_embeddings.token_embedding, sd_hijack.model_hijack) - embedder = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, sd_hijack.model_hijack) + embedder = forge_clip.CLIP_SD_XL_L(embedder, sd_hijack.model_hijack) conditioner.embedders[i] = embedder text_cond_models.append(embedder) @@ -177,7 +178,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): model_embeddings = embedder.transformer.text_model.embeddings model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes( model_embeddings.token_embedding, sd_hijack.model_hijack, textual_inversion_key='clip_g') - embedder = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, sd_hijack.model_hijack) + embedder = forge_clip.CLIP_SD_XL_G(embedder, sd_hijack.model_hijack) conditioner.embedders[i] = embedder text_cond_models.append(embedder) @@ -192,16 +193,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes( model_embeddings.token_embedding, sd_hijack.model_hijack) - sd_model.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model, - sd_hijack.model_hijack) + sd_model.cond_stage_model = forge_clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack) elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': # SD21 Clip sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_g.transformer model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes( model_embeddings.token_embedding, sd_hijack.model_hijack) - sd_model.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model, - sd_hijack.model_hijack) + sd_model.cond_stage_model = forge_clip.CLIP_SD_21_G(sd_model.cond_stage_model, sd_hijack.model_hijack) else: raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__)