diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f048e990..aefb7704 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -142,57 +142,7 @@ class StableDiffusionModelHijack: pass def hijack(self, m): - conditioner = getattr(m, 'conditioner', None) - if conditioner: - text_cond_models = [] - - for i in range(len(conditioner.embedders)): - embedder = conditioner.embedders[i] - typename = type(embedder).__name__ - if typename == 'FrozenOpenCLIPEmbedder': - model_embeddings = embedder.transformer.text_model.embeddings - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - text_cond_models.append(conditioner.embedders[i]) - if typename == 'FrozenCLIPEmbedder': - model_embeddings = embedder.transformer.text_model.embeddings - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - text_cond_models.append(conditioner.embedders[i]) - if typename == 'FrozenOpenCLIPEmbedder2': - model_embeddings = embedder.transformer.text_model.embeddings - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self, textual_inversion_key='clip_g') - conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - text_cond_models.append(conditioner.embedders[i]) - - if len(text_cond_models) == 1: - m.cond_stage_model = text_cond_models[0] - else: - m.cond_stage_model = conditioner - - elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - - elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder: - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - - self.clip = m.cond_stage_model - - apply_weighted_forward(m) - - def flatten(el): - flattened = [flatten(children) for children in el.children()] - res = [el] - for c in flattened: - res += c - return res - - self.layers = flatten(m) - sd_unet.original_forward = None + pass def undo_hijack(self, m): pass diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index fc124be1..65dc2045 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -12,7 +12,9 @@ import ldm_patched.modules.clip_vision from omegaconf import OmegaConf from modules.sd_models_config import find_checkpoint_config from modules.shared import cmd_opts -import modules.sd_hijack as sd_hijack +from modules import sd_hijack +from modules.sd_hijack import EmbeddingsWithFixes +from modules import sd_hijack_clip, sd_hijack_open_clip from modules.sd_models_xl import extend_sdxl from ldm.util import instantiate_from_config @@ -163,14 +165,33 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): if typename == 'FrozenOpenCLIPEmbedder': embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer + model_embeddings = embedder.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, + sd_hijack.model_hijack) + embedder = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, sd_hijack.model_hijack) + + conditioner.embedders[i] = embedder text_cond_models.append(embedder) elif typename == 'FrozenCLIPEmbedder': embedder.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer embedder.transformer = forge_object.clip.cond_stage_model.clip_l.transformer + model_embeddings = embedder.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, + sd_hijack.model_hijack) + embedder = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, sd_hijack.model_hijack) + + conditioner.embedders[i] = embedder text_cond_models.append(embedder) elif typename == 'FrozenOpenCLIPEmbedder2': embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer + model_embeddings = embedder.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, + sd_hijack.model_hijack, + textual_inversion_key='clip_g') + embedder = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, sd_hijack.model_hijack) + + conditioner.embedders[i] = embedder text_cond_models.append(embedder) if len(text_cond_models) == 1: @@ -180,17 +201,22 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder': sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_l.transformer + model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, sd_hijack.model_hijack) + sd_model.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model, + sd_hijack.model_hijack) elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_g.transformer + model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, sd_hijack.model_hijack) + sd_model.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model, + sd_hijack.model_hijack) else: raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__) timer.record("forge set components") - sd_hijack.model_hijack.hijack(sd_model) - timer.record("forge hijack") - sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash")