mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
Text Processing Engine is Finished
100% reproduce all previous results, including TI embeddings, LoRAs in CLIP, emphasize settings, BREAK, timestep swap scheduling, AB mixture, advanced uncond, etc Backend is 85% finished
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
WebUI Forge is under a week of major revision right now between 2024 Aug 1 and Aug 7. To join the test, just update to the latest unstable version.
|
WebUI Forge is under a week of major revision right now between 2024 Aug 1 and Aug 7. To join the test, just update to the latest unstable version.
|
||||||
|
|
||||||
**Current Progress (2024 Aug 3):** Backend Rewrite is 81% finished - remaining 30 hours to begin making it stable; remaining 48 hours to begin supporting many new things.
|
**Current Progress (2024 Aug 3):** Backend Rewrite is 85% finished - remaining 30 hours to begin making it stable; remaining 48 hours to begin supporting many new things.
|
||||||
|
|
||||||
For downloading previous versions, see [Previous Versions](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/849).
|
For downloading previous versions, see [Previous Versions](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/849).
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ import torch
|
|||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from backend.text_processing import parsing, emphasis
|
from backend.text_processing import parsing, emphasis
|
||||||
from textual_inversion import EmbeddingDatabase
|
from backend.text_processing.textual_inversion import EmbeddingDatabase
|
||||||
|
|
||||||
|
|
||||||
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||||
|
last_extra_generation_params = {}
|
||||||
|
|
||||||
|
|
||||||
class PromptChunk:
|
class PromptChunk:
|
||||||
@@ -37,6 +38,7 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
|
|||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
for offset, embedding in fixes:
|
for offset, embedding in fixes:
|
||||||
emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||||
|
emb = emb.to(inputs_embeds)
|
||||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
||||||
|
|
||||||
@@ -45,8 +47,11 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
|
|||||||
return torch.stack(vecs)
|
return torch.stack(vecs)
|
||||||
|
|
||||||
|
|
||||||
class ClassicTextProcessingEngine:
|
class ClassicTextProcessingEngine(torch.nn.Module):
|
||||||
def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original", text_projection=None, minimal_clip_skip=1, clip_skip=1, return_pooled=False, callback_before_encode=None):
|
def __init__(self, text_encoder, tokenizer, chunk_length=75,
|
||||||
|
embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
|
||||||
|
text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
|
||||||
|
callback_before_encode=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
|
self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
|
||||||
@@ -56,20 +61,21 @@ class ClassicTextProcessingEngine:
|
|||||||
self.text_encoder = text_encoder
|
self.text_encoder = text_encoder
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
self.emphasis = emphasis.get_current_option(emphasis_name)
|
self.emphasis = emphasis.get_current_option(emphasis_name)()
|
||||||
self.text_projection = text_projection
|
self.text_projection = text_projection
|
||||||
self.minimal_clip_skip = minimal_clip_skip
|
self.minimal_clip_skip = minimal_clip_skip
|
||||||
self.clip_skip = clip_skip
|
self.clip_skip = clip_skip
|
||||||
self.return_pooled = return_pooled
|
self.return_pooled = return_pooled
|
||||||
|
self.final_layer_norm = final_layer_norm
|
||||||
self.callback_before_encode = callback_before_encode
|
self.callback_before_encode = callback_before_encode
|
||||||
|
|
||||||
self.chunk_length = chunk_length
|
self.chunk_length = chunk_length
|
||||||
|
|
||||||
self.id_start = self.tokenizer.bos_token_id
|
self.id_start = self.tokenizer.bos_token_id
|
||||||
self.id_end = self.tokenizer.eos_token_id
|
self.id_end = self.tokenizer.eos_token_id
|
||||||
self.id_pad = self.id_end
|
self.id_pad = self.tokenizer.pad_token_id
|
||||||
|
|
||||||
model_embeddings = text_encoder.text_model.embeddings
|
model_embeddings = text_encoder.transformer.text_model.embeddings
|
||||||
model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
|
model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
|
||||||
|
|
||||||
vocab = self.tokenizer.get_vocab()
|
vocab = self.tokenizer.get_vocab()
|
||||||
@@ -94,9 +100,6 @@ class ClassicTextProcessingEngine:
|
|||||||
if mult != 1.0:
|
if mult != 1.0:
|
||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
# # Todo: remove these
|
|
||||||
# self.legacy_ucg_val = None # for sgm codebase
|
|
||||||
|
|
||||||
def empty_chunk(self):
|
def empty_chunk(self):
|
||||||
chunk = PromptChunk()
|
chunk = PromptChunk()
|
||||||
chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
|
chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
|
||||||
@@ -112,27 +115,25 @@ class ClassicTextProcessingEngine:
|
|||||||
return tokenized
|
return tokenized
|
||||||
|
|
||||||
def encode_with_transformers(self, tokens):
|
def encode_with_transformers(self, tokens):
|
||||||
self.text_encoder.transformer.text_model.embeddings.to(tokens.device)
|
tokens = tokens.to(self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device)
|
||||||
|
|
||||||
outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)
|
outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)
|
||||||
|
|
||||||
layer_id = - max(self.clip_skip, self.minimal_clip_skip)
|
layer_id = - max(self.clip_skip, self.minimal_clip_skip)
|
||||||
z = outputs.hidden_states[layer_id]
|
z = outputs.hidden_states[layer_id]
|
||||||
|
|
||||||
|
if self.final_layer_norm:
|
||||||
|
z = self.text_encoder.transformer.text_model.final_layer_norm(z)
|
||||||
|
|
||||||
if self.return_pooled:
|
if self.return_pooled:
|
||||||
pooled_output = outputs.pooler_output
|
pooled_output = outputs.pooler_output
|
||||||
|
|
||||||
if self.text_projection:
|
if self.text_projection:
|
||||||
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()
|
||||||
|
|
||||||
z.pooled = pooled_output
|
z.pooled = pooled_output
|
||||||
return z
|
return z
|
||||||
|
|
||||||
def encode_embedding_init_text(self, init_text, nvpt):
|
|
||||||
embedding_layer = self.text_encoder.transformer.text_model.embeddings
|
|
||||||
ids = self.text_encoder.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
|
||||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
|
||||||
return embedded
|
|
||||||
|
|
||||||
def tokenize_line(self, line):
|
def tokenize_line(self, line):
|
||||||
parsed = parsing.parse_prompt_attention(line)
|
parsed = parsing.parse_prompt_attention(line)
|
||||||
|
|
||||||
@@ -235,9 +236,9 @@ class ClassicTextProcessingEngine:
|
|||||||
|
|
||||||
return batch_chunks, token_count
|
return batch_chunks, token_count
|
||||||
|
|
||||||
def __call__(self, texts):
|
def forward(self, texts):
|
||||||
if self.callback_before_encode is not None:
|
if self.callback_before_encode is not None:
|
||||||
self.callback_before_encode()
|
self.callback_before_encode(self, texts)
|
||||||
|
|
||||||
batch_chunks, token_count = self.process_texts(texts)
|
batch_chunks, token_count = self.process_texts(texts)
|
||||||
|
|
||||||
@@ -259,28 +260,21 @@ class ClassicTextProcessingEngine:
|
|||||||
z = self.process_tokens(tokens, multipliers)
|
z = self.process_tokens(tokens, multipliers)
|
||||||
zs.append(z)
|
zs.append(z)
|
||||||
|
|
||||||
|
global last_extra_generation_params
|
||||||
|
|
||||||
|
last_extra_generation_params = {}
|
||||||
|
|
||||||
if used_embeddings:
|
if used_embeddings:
|
||||||
|
names = []
|
||||||
|
|
||||||
for name, embedding in used_embeddings.items():
|
for name, embedding in used_embeddings.items():
|
||||||
print(f'Used Embedding: {name}')
|
print(f'Used Embedding: {name}')
|
||||||
|
names.append(name.replace(":", "").replace(",", ""))
|
||||||
|
|
||||||
# Todo:
|
last_extra_generation_params["TI"] = ", ".join(names)
|
||||||
# if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
|
|
||||||
# hashes = []
|
if any(x for x in texts if "(" in x or "[" in x) and self.emphasis.name != "Original":
|
||||||
# for name, embedding in used_embeddings.items():
|
last_extra_generation_params["Emphasis"] = self.emphasis.name
|
||||||
# shorthash = embedding.shorthash
|
|
||||||
# if not shorthash:
|
|
||||||
# continue
|
|
||||||
#
|
|
||||||
# name = name.replace(":", "").replace(",", "")
|
|
||||||
# hashes.append(f"{name}: {shorthash}")
|
|
||||||
#
|
|
||||||
# if hashes:
|
|
||||||
# if self.hijack.extra_generation_params.get("TI hashes"):
|
|
||||||
# hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
|
||||||
# self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
|
||||||
#
|
|
||||||
# if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
|
||||||
# self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
|
||||||
|
|
||||||
if self.return_pooled:
|
if self.return_pooled:
|
||||||
return torch.hstack(zs), zs[0].pooled
|
return torch.hstack(zs), zs[0].pooled
|
||||||
@@ -300,7 +294,7 @@ class ClassicTextProcessingEngine:
|
|||||||
pooled = getattr(z, 'pooled', None)
|
pooled = getattr(z, 'pooled', None)
|
||||||
|
|
||||||
self.emphasis.tokens = remade_batch_tokens
|
self.emphasis.tokens = remade_batch_tokens
|
||||||
self.emphasis.multipliers = torch.asarray(batch_multipliers)
|
self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
|
||||||
self.emphasis.z = z
|
self.emphasis.z = z
|
||||||
self.emphasis.after_transformers()
|
self.emphasis.after_transformers()
|
||||||
z = self.emphasis.z
|
z = self.emphasis.z
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class EmbeddingDatabase:
|
|||||||
return self.register_embedding_by_name(embedding, embedding.name)
|
return self.register_embedding_by_name(embedding, embedding.name)
|
||||||
|
|
||||||
def register_embedding_by_name(self, embedding, name):
|
def register_embedding_by_name(self, embedding, name):
|
||||||
ids = self.tokenizer.tokenize([name])[0]
|
ids = self.tokenizer([name], truncation=False, add_special_tokens=False)["input_ids"][0]
|
||||||
first_id = ids[0]
|
first_id = ids[0]
|
||||||
if first_id not in self.ids_lookup:
|
if first_id not in self.ids_lookup:
|
||||||
self.ids_lookup[first_id] = []
|
self.ids_lookup[first_id] = []
|
||||||
|
|||||||
@@ -498,8 +498,14 @@ class StableDiffusionProcessing:
|
|||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
||||||
|
|
||||||
|
import backend.text_processing.classic_engine
|
||||||
|
last_extra_generation_params = backend.text_processing.classic_engine.last_extra_generation_params.copy()
|
||||||
|
|
||||||
|
modules.sd_hijack.model_hijack.extra_generation_params.update(last_extra_generation_params)
|
||||||
|
|
||||||
if len(cache) > 2:
|
if len(cache) > 2:
|
||||||
cache[2] = modules.sd_hijack.model_hijack.extra_generation_params
|
cache[2] = last_extra_generation_params
|
||||||
|
|
||||||
cache[0] = cached_params
|
cache[0] = cached_params
|
||||||
return cache[1]
|
return cache[1]
|
||||||
@@ -880,7 +886,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
# todo: reload ti
|
||||||
|
# model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
pass
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.process(p)
|
p.scripts.process(p)
|
||||||
|
|||||||
@@ -127,14 +127,9 @@ class StableDiffusionModelHijack:
|
|||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
import modules.textual_inversion.textual_inversion
|
|
||||||
|
|
||||||
self.extra_generation_params = {}
|
self.extra_generation_params = {}
|
||||||
self.comments = []
|
self.comments = []
|
||||||
|
|
||||||
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
|
||||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
|
||||||
|
|
||||||
def apply_optimizations(self, option=None):
|
def apply_optimizations(self, option=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -686,19 +686,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
model_data.set_sd_model(sd_model)
|
model_data.set_sd_model(sd_model)
|
||||||
model_data.was_loaded_at_least_once = True
|
model_data.was_loaded_at_least_once = True
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
|
||||||
|
|
||||||
timer.record("load textual inversion embeddings")
|
|
||||||
|
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
timer.record("scripts callbacks")
|
timer.record("scripts callbacks")
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
|
|
||||||
|
|
||||||
timer.record("calculate empty prompt")
|
|
||||||
|
|
||||||
print(f"Model loaded in {timer.summary()}.")
|
print(f"Model loaded in {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ class EmbeddingDatabase:
|
|||||||
return self.register_embedding_by_name(embedding, model, embedding.name)
|
return self.register_embedding_by_name(embedding, model, embedding.name)
|
||||||
|
|
||||||
def register_embedding_by_name(self, embedding, model, name):
|
def register_embedding_by_name(self, embedding, model, name):
|
||||||
ids = model.cond_stage_model.tokenize([name])[0]
|
ids = [0, 0, 0] # model.cond_stage_model.tokenize([name])[0]
|
||||||
first_id = ids[0]
|
first_id = ids[0]
|
||||||
if first_id not in self.ids_lookup:
|
if first_id not in self.ids_lookup:
|
||||||
self.ids_lookup[first_id] = []
|
self.ids_lookup[first_id] = []
|
||||||
@@ -183,11 +183,7 @@ class EmbeddingDatabase:
|
|||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
|
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
|
||||||
|
self.register_embedding(embedding, None)
|
||||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
|
||||||
else:
|
|
||||||
self.skipped_embeddings[name] = embedding
|
|
||||||
else:
|
else:
|
||||||
print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
|
print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import backend.nn.unet
|
|||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from modules.sd_models_config import find_checkpoint_config
|
from modules.sd_models_config import find_checkpoint_config
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts, opts
|
||||||
from modules import sd_hijack
|
from modules import sd_hijack
|
||||||
from modules.sd_models_xl import extend_sdxl
|
from modules.sd_models_xl import extend_sdxl
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
@@ -17,6 +17,7 @@ from modules_forge import clip
|
|||||||
from modules_forge.unet_patcher import UnetPatcher
|
from modules_forge.unet_patcher import UnetPatcher
|
||||||
from backend.loader import load_huggingface_components
|
from backend.loader import load_huggingface_components
|
||||||
from backend.modules.k_model import KModel
|
from backend.modules.k_model import KModel
|
||||||
|
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||||
|
|
||||||
import open_clip
|
import open_clip
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
@@ -148,6 +149,15 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
|||||||
sd_model.first_stage_model = forge_objects.vae.first_stage_model
|
sd_model.first_stage_model = forge_objects.vae.first_stage_model
|
||||||
sd_model.model.diffusion_model = forge_objects.unet.model
|
sd_model.model.diffusion_model = forge_objects.unet.model
|
||||||
|
|
||||||
|
def set_clip_skip_callback(m, ts):
|
||||||
|
m.clip_skip = opts.CLIP_stop_at_last_layers
|
||||||
|
return
|
||||||
|
|
||||||
|
def set_clip_skip_callback_and_move_model(m, ts):
|
||||||
|
memory_management.load_model_gpu(sd_model.forge_objects.clip.patcher)
|
||||||
|
m.clip_skip = opts.CLIP_stop_at_last_layers
|
||||||
|
return
|
||||||
|
|
||||||
conditioner = getattr(sd_model, 'conditioner', None)
|
conditioner = getattr(sd_model, 'conditioner', None)
|
||||||
if conditioner:
|
if conditioner:
|
||||||
text_cond_models = []
|
text_cond_models = []
|
||||||
@@ -156,23 +166,44 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
|||||||
embedder = conditioner.embedders[i]
|
embedder = conditioner.embedders[i]
|
||||||
typename = type(embedder).__name__
|
typename = type(embedder).__name__
|
||||||
if typename == 'FrozenCLIPEmbedder': # SDXL Clip L
|
if typename == 'FrozenCLIPEmbedder': # SDXL Clip L
|
||||||
embedder.tokenizer = forge_objects.clip.tokenizer.clip_l
|
engine = ClassicTextProcessingEngine(
|
||||||
embedder.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
|
text_encoder=forge_objects.clip.cond_stage_model.clip_l,
|
||||||
model_embeddings = embedder.transformer.text_model.embeddings
|
tokenizer=forge_objects.clip.tokenizer.clip_l,
|
||||||
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
|
embedding_dir=cmd_opts.embeddings_dir,
|
||||||
model_embeddings.token_embedding, sd_hijack.model_hijack)
|
embedding_key='clip_l',
|
||||||
embedder = clip.CLIP_SD_XL_L(embedder, sd_hijack.model_hijack)
|
embedding_expected_shape=2048,
|
||||||
conditioner.embedders[i] = embedder
|
emphasis_name=opts.emphasis,
|
||||||
|
text_projection=False,
|
||||||
|
minimal_clip_skip=2,
|
||||||
|
clip_skip=2,
|
||||||
|
return_pooled=False,
|
||||||
|
final_layer_norm=False,
|
||||||
|
callback_before_encode=set_clip_skip_callback
|
||||||
|
)
|
||||||
|
engine.is_trainable = False # for sgm codebase
|
||||||
|
engine.legacy_ucg_val = None # for sgm codebase
|
||||||
|
engine.input_key = 'txt' # for sgm codebase
|
||||||
|
conditioner.embedders[i] = engine
|
||||||
text_cond_models.append(embedder)
|
text_cond_models.append(embedder)
|
||||||
elif typename == 'FrozenOpenCLIPEmbedder2': # SDXL Clip G
|
elif typename == 'FrozenOpenCLIPEmbedder2': # SDXL Clip G
|
||||||
embedder.tokenizer = forge_objects.clip.tokenizer.clip_g
|
engine = ClassicTextProcessingEngine(
|
||||||
embedder.transformer = forge_objects.clip.cond_stage_model.clip_g.transformer
|
text_encoder=forge_objects.clip.cond_stage_model.clip_g,
|
||||||
embedder.text_projection = forge_objects.clip.cond_stage_model.clip_g.text_projection
|
tokenizer=forge_objects.clip.tokenizer.clip_g,
|
||||||
model_embeddings = embedder.transformer.text_model.embeddings
|
embedding_dir=cmd_opts.embeddings_dir,
|
||||||
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
|
embedding_key='clip_g',
|
||||||
model_embeddings.token_embedding, sd_hijack.model_hijack, textual_inversion_key='clip_g')
|
embedding_expected_shape=2048,
|
||||||
embedder = clip.CLIP_SD_XL_G(embedder, sd_hijack.model_hijack)
|
emphasis_name=opts.emphasis,
|
||||||
conditioner.embedders[i] = embedder
|
text_projection=True,
|
||||||
|
minimal_clip_skip=2,
|
||||||
|
clip_skip=2,
|
||||||
|
return_pooled=True,
|
||||||
|
final_layer_norm=False,
|
||||||
|
callback_before_encode=set_clip_skip_callback
|
||||||
|
)
|
||||||
|
engine.is_trainable = False # for sgm codebase
|
||||||
|
engine.legacy_ucg_val = None # for sgm codebase
|
||||||
|
engine.input_key = 'txt' # for sgm codebase
|
||||||
|
conditioner.embedders[i] = engine
|
||||||
text_cond_models.append(embedder)
|
text_cond_models.append(embedder)
|
||||||
|
|
||||||
if len(text_cond_models) == 1:
|
if len(text_cond_models) == 1:
|
||||||
@@ -180,19 +211,37 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
|||||||
else:
|
else:
|
||||||
sd_model.cond_stage_model = conditioner
|
sd_model.cond_stage_model = conditioner
|
||||||
elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder': # SD15 Clip
|
elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder': # SD15 Clip
|
||||||
sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
|
engine = ClassicTextProcessingEngine(
|
||||||
sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
|
text_encoder=forge_objects.clip.cond_stage_model.clip_l,
|
||||||
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
|
tokenizer=forge_objects.clip.tokenizer.clip_l,
|
||||||
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
|
embedding_dir=cmd_opts.embeddings_dir,
|
||||||
model_embeddings.token_embedding, sd_hijack.model_hijack)
|
embedding_key='clip_l',
|
||||||
sd_model.cond_stage_model = clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
|
embedding_expected_shape=768,
|
||||||
|
emphasis_name=opts.emphasis,
|
||||||
|
text_projection=False,
|
||||||
|
minimal_clip_skip=1,
|
||||||
|
clip_skip=1,
|
||||||
|
return_pooled=False,
|
||||||
|
final_layer_norm=True,
|
||||||
|
callback_before_encode=set_clip_skip_callback_and_move_model
|
||||||
|
)
|
||||||
|
sd_model.cond_stage_model = engine
|
||||||
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': # SD21 Clip
|
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': # SD21 Clip
|
||||||
sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
|
engine = ClassicTextProcessingEngine(
|
||||||
sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
|
text_encoder=forge_objects.clip.cond_stage_model.clip_l,
|
||||||
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
|
tokenizer=forge_objects.clip.tokenizer.clip_l,
|
||||||
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
|
embedding_dir=cmd_opts.embeddings_dir,
|
||||||
model_embeddings.token_embedding, sd_hijack.model_hijack)
|
embedding_key='clip_l',
|
||||||
sd_model.cond_stage_model = clip.CLIP_SD_21_H(sd_model.cond_stage_model, sd_hijack.model_hijack)
|
embedding_expected_shape=1024,
|
||||||
|
emphasis_name=opts.emphasis,
|
||||||
|
text_projection=False,
|
||||||
|
minimal_clip_skip=1,
|
||||||
|
clip_skip=1,
|
||||||
|
return_pooled=False,
|
||||||
|
final_layer_norm=True,
|
||||||
|
callback_before_encode=set_clip_skip_callback_and_move_model
|
||||||
|
)
|
||||||
|
sd_model.cond_stage_model = engine
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__)
|
raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user