Text Processing Engine is Finished
100% reproduces all previous results, including TI embeddings, LoRAs in CLIP, emphasis settings, BREAK, timestep swap scheduling, AB mixture, advanced uncond, etc. The backend is 85% finished.
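For orientation, a minimal sketch of how the new engine is driven, inferred only from the constructor and forward() signatures in this diff; the text_encoder and tokenizer objects are assumptions standing in for whatever the backend actually passes (text_encoder must expose .transformer.text_model, tokenizer must be the matching CLIP tokenizer):

    # Hypothetical wiring; only the keyword names come from this diff.
    engine = ClassicTextProcessingEngine(
        text_encoder=text_encoder,      # assumed: wraps a CLIP text model
        tokenizer=tokenizer,            # assumed: matching CLIP tokenizer
        embedding_key='clip_l',
        embedding_expected_shape=768,
        emphasis_name="original",
        clip_skip=2,                    # clamped from below by minimal_clip_skip
        return_pooled=False,
    )
    cond = engine(["a photo of a cat"])  # nn.Module __call__ dispatches to forward()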
@@ -3,10 +3,11 @@ import torch
 from collections import namedtuple

 from backend.text_processing import parsing, emphasis
-from textual_inversion import EmbeddingDatabase
+from backend.text_processing.textual_inversion import EmbeddingDatabase

 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+last_extra_generation_params = {}


 class PromptChunk:
@@ -37,6 +38,7 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
                 emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb = emb.to(inputs_embeds)
                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                 tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
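The splice in that loop can be checked in isolation. A toy sketch (not from the repo) of how a learned vector overwrites the token embeddings that follow offset while keeping the sequence length fixed:

    import torch

    tensor = torch.zeros(5, 4)    # embeddings for one prompt: 5 tokens, dim 4
    emb = torch.ones(2, 4)        # a textual inversion vector spanning 2 tokens
    offset = 1                    # replacement starts at index offset + 1

    emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
    tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

    assert tensor.shape == (5, 4)     # length preserved
    assert tensor[2:4].eq(1).all()    # rows 2..3 now hold the TI vector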
@@ -45,8 +47,11 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
         return torch.stack(vecs)


-class ClassicTextProcessingEngine:
-    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original", text_projection=None, minimal_clip_skip=1, clip_skip=1, return_pooled=False, callback_before_encode=None):
+class ClassicTextProcessingEngine(torch.nn.Module):
+    def __init__(self, text_encoder, tokenizer, chunk_length=75,
+                 embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
+                 text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
+                 callback_before_encode=None):
+        super().__init__()

         self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
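Turning the engine into a torch.nn.Module (with super().__init__() before any attribute assignment) is what makes the __call__-to-forward rename later in this diff work: nn.Module.__call__ runs registered hooks and then dispatches to forward. A standalone illustration:

    import torch

    class Doubler(torch.nn.Module):
        def __init__(self):
            super().__init__()   # must run before submodules/buffers are assigned

        def forward(self, x):
            return x * 2

    d = Doubler()
    print(d(torch.tensor([1, 2])))   # __call__ -> hooks -> forward -> tensor([2, 4])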
@@ -56,20 +61,21 @@ class ClassicTextProcessingEngine:
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer

-        self.emphasis = emphasis.get_current_option(emphasis_name)
+        self.emphasis = emphasis.get_current_option(emphasis_name)()
         self.text_projection = text_projection
         self.minimal_clip_skip = minimal_clip_skip
         self.clip_skip = clip_skip
         self.return_pooled = return_pooled
+        self.final_layer_norm = final_layer_norm
         self.callback_before_encode = callback_before_encode

         self.chunk_length = chunk_length

         self.id_start = self.tokenizer.bos_token_id
         self.id_end = self.tokenizer.eos_token_id
-        self.id_pad = self.id_end
+        self.id_pad = self.tokenizer.pad_token_id

-        model_embeddings = text_encoder.text_model.embeddings
+        model_embeddings = text_encoder.transformer.text_model.embeddings
         model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)

         vocab = self.tokenizer.get_vocab()
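The trailing () on get_current_option is the notable fix in this hunk: the helper evidently returns an emphasis class, and the engine needs its own instance because process_tokens later assigns per-call state (.tokens, .multipliers, .z) onto it. A minimal sketch of that registry pattern, with hypothetical names:

    class OriginalEmphasis:
        name = "Original"

        def after_transformers(self):
            pass   # the real option rescales self.z by self.multipliers

    def get_current_option(name):
        return {"original": OriginalEmphasis}[name]   # returns the class itself

    emphasis_obj = get_current_option("original")()   # the second () instantiates it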
@@ -94,9 +100,6 @@ class ClassicTextProcessingEngine:
             if mult != 1.0:
                 self.token_mults[ident] = mult

-        # # Todo: remove these
-        # self.legacy_ucg_val = None  # for sgm codebase
-
     def empty_chunk(self):
         chunk = PromptChunk()
         chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
@@ -112,27 +115,25 @@ class ClassicTextProcessingEngine:
         return tokenized

     def encode_with_transformers(self, tokens):
-        self.text_encoder.transformer.text_model.embeddings.to(tokens.device)
+        tokens = tokens.to(self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device)

         outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)

         layer_id = - max(self.clip_skip, self.minimal_clip_skip)
         z = outputs.hidden_states[layer_id]

-        z = self.text_encoder.transformer.text_model.final_layer_norm(z)
+        if self.final_layer_norm:
+            z = self.text_encoder.transformer.text_model.final_layer_norm(z)

         if self.return_pooled:
             pooled_output = outputs.pooler_output

             if self.text_projection:
-                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+                pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()

             z.pooled = pooled_output
         return z

     def encode_embedding_init_text(self, init_text, nvpt):
         embedding_layer = self.text_encoder.transformer.text_model.embeddings
         ids = self.text_encoder.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
         embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
         return embedded

     def tokenize_line(self, line):
         parsed = parsing.parse_prompt_attention(line)
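The clip-skip indexing can be exercised directly against a Hugging Face CLIP text model; the sketch below assumes transformer behaves like transformers.CLIPTextModel (the diff never names the concrete class):

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    transformer = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    tokens = tokenizer("a cat", padding="max_length", max_length=77, return_tensors="pt").input_ids
    outputs = transformer(tokens, output_hidden_states=True)

    layer_id = -max(2, 1)                   # clip_skip=2, minimal_clip_skip=1
    z = outputs.hidden_states[layer_id]     # hidden_states[-1] is the last layer
    z = transformer.text_model.final_layer_norm(z)   # re-normalize the skipped-to layer

    pooled = outputs.pooler_output          # EOS-pooled state used when return_pooled=True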
@@ -235,9 +236,9 @@ class ClassicTextProcessingEngine:

         return batch_chunks, token_count

-    def __call__(self, texts):
+    def forward(self, texts):
         if self.callback_before_encode is not None:
-            self.callback_before_encode()
+            self.callback_before_encode(self, texts)

         batch_chunks, token_count = self.process_texts(texts)
@@ -259,28 +260,21 @@ class ClassicTextProcessingEngine:
             z = self.process_tokens(tokens, multipliers)
             zs.append(z)

+        global last_extra_generation_params
+        last_extra_generation_params = {}
+
         if used_embeddings:
             names = []

             for name, embedding in used_embeddings.items():
                 print(f'Used Embedding: {name}')
                 names.append(name.replace(":", "").replace(",", ""))

-        # Todo:
-        # if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
-        #     hashes = []
-        #     for name, embedding in used_embeddings.items():
-        #         shorthash = embedding.shorthash
-        #         if not shorthash:
-        #             continue
-        #
-        #         name = name.replace(":", "").replace(",", "")
-        #         hashes.append(f"{name}: {shorthash}")
-        #
-        #     if hashes:
-        #         if self.hijack.extra_generation_params.get("TI hashes"):
-        #             hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
-        #         self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
-        #
-        # if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
-        #     self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
+            last_extra_generation_params["TI"] = ", ".join(names)
+
+        if any(x for x in texts if "(" in x or "[" in x) and self.emphasis.name != "Original":
+            last_extra_generation_params["Emphasis"] = self.emphasis.name

         if self.return_pooled:
             return torch.hstack(zs), zs[0].pooled
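Because last_extra_generation_params is a module-level dict that is reset and refilled on every forward pass, it presumably has to be read immediately after encoding; a hypothetical consumer (module name assumed):

    import classic_engine   # hypothetical name for this file

    cond = engine(["a (cat:1.2)"])

    extra = dict(classic_engine.last_extra_generation_params)   # copy before the next call
    if "TI" in extra:
        print("Embeddings used:", extra["TI"])
    if "Emphasis" in extra:
        print("Emphasis mode:", extra["Emphasis"])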
@@ -300,7 +294,7 @@ class ClassicTextProcessingEngine:
         pooled = getattr(z, 'pooled', None)

         self.emphasis.tokens = remade_batch_tokens
-        self.emphasis.multipliers = torch.asarray(batch_multipliers)
+        self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
         self.emphasis.z = z
         self.emphasis.after_transformers()
         z = self.emphasis.z
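The .to(z) added here matters because tensor-to-tensor .to() matches both the device and the dtype of its argument, so the emphasis multipliers follow the encoder output onto fp16/GPU instead of staying fp32 on CPU. A standalone check:

    import torch

    z = torch.zeros(2, 3, dtype=torch.float16)   # stands in for the encoder output
    mults = torch.asarray([[1.0, 1.2, 0.8], [1.0, 1.0, 1.0]]).to(z)

    assert mults.dtype == z.dtype and mults.device == z.device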