Update sd_hijack.py

This commit is contained in:
lllyasviel
2024-01-25 05:04:27 -08:00
parent 6ff68b4aa4
commit 231b860e92

View File

@@ -185,11 +185,6 @@ class StableDiffusionModelHijack:
else:
m.cond_stage_model = conditioner
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
@@ -230,39 +225,7 @@ class StableDiffusionModelHijack:
def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
for i in range(len(conditioner.embedders)):
embedder = conditioner.embedders[i]
if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
conditioner.embedders[i] = embedder.wrapped
if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
conditioner.embedders[i] = embedder.wrapped
if hasattr(m, 'cond_stage_model'):
delattr(m, 'cond_stage_model')
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
m.cond_stage_model = m.cond_stage_model.wrapped
undo_optimizations()
undo_weighted_forward(m)
self.apply_circular(False)
self.layers = None
self.clip = None
pass
def apply_circular(self, enable):
@@ -287,8 +250,7 @@ class StableDiffusionModelHijack:
return token_count, self.clip.get_target_prompt_token_count(token_count)
def redo_hijack(self, m):
self.undo_hijack(m)
self.hijack(m)
pass
class EmbeddingsWithFixes(torch.nn.Module):