diff --git a/README.md b/README.md
index 4bd04453..baeef097 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 WebUI Forge is under a week of major revision right now between 2024 Aug 1 and Aug 7. To join the test, just update to the latest unstable version.
 
-**Current Progress (2024 Aug 3):** Backend Rewrite is 81% finished - remaining 30 hours to begin making it stable; remaining 48 hours to begin supporting many new things.
+**Current Progress (2024 Aug 3):** Backend Rewrite is 85% finished - remaining 30 hours to begin making it stable; remaining 48 hours to begin supporting many new things.
 
 For downloading previous versions, see [Previous Versions](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/849).
diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py
index 241d70cf..dcc7695b 100644
--- a/backend/text_processing/classic_engine.py
+++ b/backend/text_processing/classic_engine.py
@@ -3,10 +3,11 @@
 import torch
 
 from collections import namedtuple
 from backend.text_processing import parsing, emphasis
-from textual_inversion import EmbeddingDatabase
+from backend.text_processing.textual_inversion import EmbeddingDatabase
 
 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+last_extra_generation_params = {}
 
 
 class PromptChunk:
@@ -37,6 +38,7 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
                 emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb = emb.to(inputs_embeds)
                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                 tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
@@ -45,8 +47,11 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
         return torch.stack(vecs)
 
 
-class ClassicTextProcessingEngine:
-    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original", text_projection=None, minimal_clip_skip=1, clip_skip=1, return_pooled=False, callback_before_encode=None):
+class ClassicTextProcessingEngine(torch.nn.Module):
+    def __init__(self, text_encoder, tokenizer, chunk_length=75,
+                 embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
+                 text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
+                 callback_before_encode=None):
         super().__init__()
 
         self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
@@ -56,20 +61,21 @@ class ClassicTextProcessingEngine:
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
 
-        self.emphasis = emphasis.get_current_option(emphasis_name)
+        self.emphasis = emphasis.get_current_option(emphasis_name)()
         self.text_projection = text_projection
         self.minimal_clip_skip = minimal_clip_skip
         self.clip_skip = clip_skip
         self.return_pooled = return_pooled
+        self.final_layer_norm = final_layer_norm
         self.callback_before_encode = callback_before_encode
 
         self.chunk_length = chunk_length
 
         self.id_start = self.tokenizer.bos_token_id
         self.id_end = self.tokenizer.eos_token_id
-        self.id_pad = self.id_end
+        self.id_pad = self.tokenizer.pad_token_id
 
-        model_embeddings = text_encoder.text_model.embeddings
+        model_embeddings = text_encoder.transformer.text_model.embeddings
         model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
 
         vocab = self.tokenizer.get_vocab()
@@ -94,9 +100,6 @@ class ClassicTextProcessingEngine:
             if mult != 1.0:
                 self.token_mults[ident] = mult
 
-        # # Todo: remove these
-        # self.legacy_ucg_val = None  # for sgm codebase
-
     def empty_chunk(self):
         chunk = PromptChunk()
         chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
@@ -112,27 +115,25 @@ class ClassicTextProcessingEngine:
         return tokenized
 
     def encode_with_transformers(self, tokens):
-        self.text_encoder.transformer.text_model.embeddings.to(tokens.device)
+        tokens = tokens.to(self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device)
+
         outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)
 
         layer_id = - max(self.clip_skip, self.minimal_clip_skip)
         z = outputs.hidden_states[layer_id]
 
+        if self.final_layer_norm:
+            z = self.text_encoder.transformer.text_model.final_layer_norm(z)
+
         if self.return_pooled:
             pooled_output = outputs.pooler_output
 
             if self.text_projection:
-                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+                pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()
 
             z.pooled = pooled_output
 
         return z
 
-    def encode_embedding_init_text(self, init_text, nvpt):
-        embedding_layer = self.text_encoder.transformer.text_model.embeddings
-        ids = self.text_encoder.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
-        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
-        return embedded
-
     def tokenize_line(self, line):
         parsed = parsing.parse_prompt_attention(line)
@@ -235,9 +236,9 @@ class ClassicTextProcessingEngine:
 
         return batch_chunks, token_count
 
-    def __call__(self, texts):
+    def forward(self, texts):
         if self.callback_before_encode is not None:
-            self.callback_before_encode()
+            self.callback_before_encode(self, texts)
 
         batch_chunks, token_count = self.process_texts(texts)
@@ -259,28 +260,21 @@ class ClassicTextProcessingEngine:
             z = self.process_tokens(tokens, multipliers)
             zs.append(z)
 
+        global last_extra_generation_params
+
+        last_extra_generation_params = {}
+
         if used_embeddings:
+            names = []
             for name, embedding in used_embeddings.items():
                 print(f'Used Embedding: {name}')
+                names.append(name.replace(":", "").replace(",", ""))
 
-        # Todo:
-        # if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
-        #     hashes = []
-        #     for name, embedding in used_embeddings.items():
-        #         shorthash = embedding.shorthash
-        #         if not shorthash:
-        #             continue
-        #
-        #         name = name.replace(":", "").replace(",", "")
-        #         hashes.append(f"{name}: {shorthash}")
-        #
-        #     if hashes:
-        #         if self.hijack.extra_generation_params.get("TI hashes"):
-        #             hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
-        #         self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
-        #
-        # if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
-        #     self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
+            last_extra_generation_params["TI"] = ", ".join(names)
+
+        if any(x for x in texts if "(" in x or "[" in x) and self.emphasis.name != "Original":
+            last_extra_generation_params["Emphasis"] = self.emphasis.name
 
         if self.return_pooled:
             return torch.hstack(zs), zs[0].pooled
@@ -300,7 +294,7 @@ class ClassicTextProcessingEngine:
         pooled = getattr(z, 'pooled', None)
 
         self.emphasis.tokens = remade_batch_tokens
-        self.emphasis.multipliers = torch.asarray(batch_multipliers)
+        self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
         self.emphasis.z = z
         self.emphasis.after_transformers()
         z = self.emphasis.z
diff --git a/backend/text_processing/textual_inversion.py b/backend/text_processing/textual_inversion.py
index cd4488b4..2a95513d 100644
--- a/backend/text_processing/textual_inversion.py
+++ b/backend/text_processing/textual_inversion.py
@@ -128,7 +128,7 @@ class EmbeddingDatabase:
         return self.register_embedding_by_name(embedding, embedding.name)
 
     def register_embedding_by_name(self, embedding, name):
-        ids = self.tokenizer.tokenize([name])[0]
+        ids = self.tokenizer([name], truncation=False, add_special_tokens=False)["input_ids"][0]
         first_id = ids[0]
         if first_id not in self.ids_lookup:
             self.ids_lookup[first_id] = []
diff --git a/modules/processing.py b/modules/processing.py
index df2795e2..08a3b781 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -498,8 +498,14 @@ class StableDiffusionProcessing:
 
         with devices.autocast():
             cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
+
+        import backend.text_processing.classic_engine
+        last_extra_generation_params = backend.text_processing.classic_engine.last_extra_generation_params.copy()
+
+        modules.sd_hijack.model_hijack.extra_generation_params.update(last_extra_generation_params)
+
         if len(cache) > 2:
-            cache[2] = modules.sd_hijack.model_hijack.extra_generation_params
+            cache[2] = last_extra_generation_params
 
         cache[0] = cached_params
         return cache[1]
@@ -880,7 +886,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
 
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
-        model_hijack.embedding_db.load_textual_inversion_embeddings()
+        # todo: reload ti
+        # model_hijack.embedding_db.load_textual_inversion_embeddings()
+        pass
 
     if p.scripts is not None:
         p.scripts.process(p)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 8050b278..ffaadc89 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -127,14 +127,9 @@ class StableDiffusionModelHijack:
     optimization_method = None
 
    def __init__(self):
-        import modules.textual_inversion.textual_inversion
-
        self.extra_generation_params = {}
        self.comments = []
 
-        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
-
    def apply_optimizations(self, option=None):
        pass
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 98e5249c..92537ece 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -686,19 +686,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     model_data.set_sd_model(sd_model)
     model_data.was_loaded_at_least_once = True
 
-    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
-
-    timer.record("load textual inversion embeddings")
-
     script_callbacks.model_loaded_callback(sd_model)
 
     timer.record("scripts callbacks")
 
-    with torch.no_grad():
-        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
-
-    timer.record("calculate empty prompt")
-
     print(f"Model loaded in {timer.summary()}.")
 
     return sd_model
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4aa14fe4..3a4cbf9b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -127,7 +127,7 @@ class EmbeddingDatabase:
         return self.register_embedding_by_name(embedding, model, embedding.name)
 
     def register_embedding_by_name(self, embedding, model, name):
-        ids = model.cond_stage_model.tokenize([name])[0]
+        ids = [0, 0, 0]  # model.cond_stage_model.tokenize([name])[0]
         first_id = ids[0]
         if first_id not in self.ids_lookup:
             self.ids_lookup[first_id] = []
@@ -183,11 +183,7 @@ class EmbeddingDatabase:
 
         if data is not None:
             embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
-
-            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
-                self.register_embedding(embedding, shared.sd_model)
-            else:
-                self.skipped_embeddings[name] = embedding
+            self.register_embedding(embedding, None)
         else:
             print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
diff --git a/modules_forge/loader.py b/modules_forge/loader.py
index 074a31bb..93c7fad8 100644
--- a/modules_forge/loader.py
+++ b/modules_forge/loader.py
@@ -9,7 +9,7 @@
 import backend.nn.unet
 from omegaconf import OmegaConf
 from modules.sd_models_config import find_checkpoint_config
-from modules.shared import cmd_opts
+from modules.shared import cmd_opts, opts
 from modules import sd_hijack
 from modules.sd_models_xl import extend_sdxl
 from ldm.util import instantiate_from_config
@@ -17,6 +17,7 @@ from modules_forge import clip
 from modules_forge.unet_patcher import UnetPatcher
 from backend.loader import load_huggingface_components
 from backend.modules.k_model import KModel
+from backend.text_processing.classic_engine import ClassicTextProcessingEngine
 
 import open_clip
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -148,6 +149,15 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
     sd_model.first_stage_model = forge_objects.vae.first_stage_model
     sd_model.model.diffusion_model = forge_objects.unet.model
 
+    def set_clip_skip_callback(m, ts):
+        m.clip_skip = opts.CLIP_stop_at_last_layers
+        return
+
+    def set_clip_skip_callback_and_move_model(m, ts):
+        memory_management.load_model_gpu(sd_model.forge_objects.clip.patcher)
+        m.clip_skip = opts.CLIP_stop_at_last_layers
+        return
+
     conditioner = getattr(sd_model, 'conditioner', None)
     if conditioner:
         text_cond_models = []
@@ -156,23 +166,44 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
             embedder = conditioner.embedders[i]
             typename = type(embedder).__name__
             if typename == 'FrozenCLIPEmbedder':  # SDXL Clip L
-                embedder.tokenizer = forge_objects.clip.tokenizer.clip_l
-                embedder.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
-                model_embeddings = embedder.transformer.text_model.embeddings
-                model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
-                    model_embeddings.token_embedding, sd_hijack.model_hijack)
-                embedder = clip.CLIP_SD_XL_L(embedder, sd_hijack.model_hijack)
-                conditioner.embedders[i] = embedder
+                engine = ClassicTextProcessingEngine(
+                    text_encoder=forge_objects.clip.cond_stage_model.clip_l,
+                    tokenizer=forge_objects.clip.tokenizer.clip_l,
+                    embedding_dir=cmd_opts.embeddings_dir,
+                    embedding_key='clip_l',
+                    embedding_expected_shape=2048,
+                    emphasis_name=opts.emphasis,
+                    text_projection=False,
+                    minimal_clip_skip=2,
+                    clip_skip=2,
+                    return_pooled=False,
+                    final_layer_norm=False,
+                    callback_before_encode=set_clip_skip_callback
+                )
+                engine.is_trainable = False  # for sgm codebase
+                engine.legacy_ucg_val = None  # for sgm codebase
+                engine.input_key = 'txt'  # for sgm codebase
+                conditioner.embedders[i] = engine
                 text_cond_models.append(embedder)
             elif typename == 'FrozenOpenCLIPEmbedder2':  # SDXL Clip G
-                embedder.tokenizer = forge_objects.clip.tokenizer.clip_g
-                embedder.transformer = forge_objects.clip.cond_stage_model.clip_g.transformer
-                embedder.text_projection = forge_objects.clip.cond_stage_model.clip_g.text_projection
-                model_embeddings = embedder.transformer.text_model.embeddings
-                model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
-                    model_embeddings.token_embedding, sd_hijack.model_hijack, textual_inversion_key='clip_g')
-                embedder = clip.CLIP_SD_XL_G(embedder, sd_hijack.model_hijack)
-                conditioner.embedders[i] = embedder
+                engine = ClassicTextProcessingEngine(
+                    text_encoder=forge_objects.clip.cond_stage_model.clip_g,
+                    tokenizer=forge_objects.clip.tokenizer.clip_g,
+                    embedding_dir=cmd_opts.embeddings_dir,
+                    embedding_key='clip_g',
+                    embedding_expected_shape=2048,
+                    emphasis_name=opts.emphasis,
+                    text_projection=True,
+                    minimal_clip_skip=2,
+                    clip_skip=2,
+                    return_pooled=True,
+                    final_layer_norm=False,
+                    callback_before_encode=set_clip_skip_callback
+                )
+                engine.is_trainable = False  # for sgm codebase
+                engine.legacy_ucg_val = None  # for sgm codebase
+                engine.input_key = 'txt'  # for sgm codebase
+                conditioner.embedders[i] = engine
                 text_cond_models.append(embedder)
 
             if len(text_cond_models) == 1:
@@ -180,19 +211,37 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
         else:
             sd_model.cond_stage_model = conditioner
     elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':  # SD15 Clip
-        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
-        sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
-        model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
-        model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
-            model_embeddings.token_embedding, sd_hijack.model_hijack)
-        sd_model.cond_stage_model = clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
+        engine = ClassicTextProcessingEngine(
+            text_encoder=forge_objects.clip.cond_stage_model.clip_l,
+            tokenizer=forge_objects.clip.tokenizer.clip_l,
+            embedding_dir=cmd_opts.embeddings_dir,
+            embedding_key='clip_l',
+            embedding_expected_shape=768,
+            emphasis_name=opts.emphasis,
+            text_projection=False,
+            minimal_clip_skip=1,
+            clip_skip=1,
+            return_pooled=False,
+            final_layer_norm=True,
+            callback_before_encode=set_clip_skip_callback_and_move_model
+        )
+        sd_model.cond_stage_model = engine
     elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':  # SD21 Clip
-        sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
-        sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
-        model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
-        model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
-            model_embeddings.token_embedding, sd_hijack.model_hijack)
-        sd_model.cond_stage_model = clip.CLIP_SD_21_H(sd_model.cond_stage_model, sd_hijack.model_hijack)
+        engine = ClassicTextProcessingEngine(
+            text_encoder=forge_objects.clip.cond_stage_model.clip_l,
+            tokenizer=forge_objects.clip.tokenizer.clip_l,
+            embedding_dir=cmd_opts.embeddings_dir,
+            embedding_key='clip_l',
+            embedding_expected_shape=1024,
+            emphasis_name=opts.emphasis,
+            text_projection=False,
+            minimal_clip_skip=1,
+            clip_skip=1,
+            return_pooled=False,
+            final_layer_norm=True,
+            callback_before_encode=set_clip_skip_callback_and_move_model
+        )
+        sd_model.cond_stage_model = engine
     else:
         raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__)