This commit is contained in:
lllyasviel
2024-01-25 05:47:54 -08:00
parent feac0d7f2d
commit 4337710c4a
2 changed files with 31 additions and 55 deletions

View File

@@ -12,7 +12,9 @@ import ldm_patched.modules.clip_vision
from omegaconf import OmegaConf
from modules.sd_models_config import find_checkpoint_config
from modules.shared import cmd_opts
import modules.sd_hijack as sd_hijack
from modules import sd_hijack
from modules.sd_hijack import EmbeddingsWithFixes
from modules import sd_hijack_clip, sd_hijack_open_clip
from modules.sd_models_xl import extend_sdxl
from ldm.util import instantiate_from_config
@@ -163,14 +165,33 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
if typename == 'FrozenOpenCLIPEmbedder':
embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding,
sd_hijack.model_hijack)
embedder = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, sd_hijack.model_hijack)
conditioner.embedders[i] = embedder
text_cond_models.append(embedder)
elif typename == 'FrozenCLIPEmbedder':
embedder.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer
embedder.transformer = forge_object.clip.cond_stage_model.clip_l.transformer
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding,
sd_hijack.model_hijack)
embedder = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, sd_hijack.model_hijack)
conditioner.embedders[i] = embedder
text_cond_models.append(embedder)
elif typename == 'FrozenOpenCLIPEmbedder2':
embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding,
sd_hijack.model_hijack,
textual_inversion_key='clip_g')
embedder = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, sd_hijack.model_hijack)
conditioner.embedders[i] = embedder
text_cond_models.append(embedder)
if len(text_cond_models) == 1:
@@ -180,17 +201,22 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':
sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer
sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_l.transformer
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, sd_hijack.model_hijack)
sd_model.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model,
sd_hijack.model_hijack)
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':
sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, sd_hijack.model_hijack)
sd_model.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model,
sd_hijack.model_hijack)
else:
raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__)
timer.record("forge set components")
sd_hijack.model_hijack.hijack(sd_model)
timer.record("forge hijack")
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")