mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-09 23:19:48 +00:00
emphasis
This commit is contained in:
@@ -46,7 +46,7 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
|
||||
|
||||
|
||||
class ClassicTextProcessingEngine:
|
||||
def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768):
|
||||
def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original"):
|
||||
super().__init__()
|
||||
|
||||
self.chunk_length = chunk_length
|
||||
@@ -55,6 +55,7 @@ class ClassicTextProcessingEngine:
|
||||
self.embeddings.load_textual_inversion_embeddings()
|
||||
self.text_encoder = text_encoder
|
||||
self.tokenizer = tokenizer
|
||||
self.emphasis = emphasis.get_current_option(emphasis_name)
|
||||
|
||||
model_embeddings = text_encoder.text_model.embeddings
|
||||
model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
|
||||
@@ -84,6 +85,7 @@ class ClassicTextProcessingEngine:
|
||||
self.id_start = self.tokenizer.bos_token_id
|
||||
self.id_end = self.tokenizer.eos_token_id
|
||||
self.id_pad = self.id_end
|
||||
self.return_pooled = True
|
||||
|
||||
# Todo: remove these
|
||||
self.legacy_ucg_val = None # for sgm codebase
|
||||
@@ -98,7 +100,7 @@ class ClassicTextProcessingEngine:
|
||||
return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
|
||||
|
||||
def tokenize(self, texts):
|
||||
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
return tokenized
|
||||
|
||||
@@ -106,8 +108,8 @@ class ClassicTextProcessingEngine:
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.transformer.text_model.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedding_layer = self.text_encoder.transformer.text_model.embeddings
|
||||
ids = self.text_encoder.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
return embedded
|
||||
|
||||
@@ -122,8 +124,6 @@ class ClassicTextProcessingEngine:
|
||||
last_comma = -1
|
||||
|
||||
def next_chunk(is_last=False):
|
||||
"""puts current chunk into the list of results and produces the next one - empty;
|
||||
if is_last is true, the <end-of-text> tokens at the end won't add to token_count"""
|
||||
nonlocal token_count
|
||||
nonlocal last_comma
|
||||
nonlocal chunk
|
||||
@@ -175,7 +175,7 @@ class ClassicTextProcessingEngine:
|
||||
if len(chunk.tokens) == self.chunk_length:
|
||||
next_chunk()
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
|
||||
embedding, embedding_length_in_tokens = self.embeddings.find_embedding_at_position(tokens, position)
|
||||
if embedding is None:
|
||||
chunk.tokens.append(token)
|
||||
chunk.multipliers.append(weight)
|
||||
@@ -227,9 +227,9 @@ class ClassicTextProcessingEngine:
|
||||
|
||||
tokens = [x.tokens for x in batch_chunk]
|
||||
multipliers = [x.multipliers for x in batch_chunk]
|
||||
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||
self.embeddings.fixes = [x.fixes for x in batch_chunk]
|
||||
|
||||
for fixes in self.hijack.fixes:
|
||||
for fixes in self.embeddings.fixes:
|
||||
for _position, embedding in fixes:
|
||||
used_embeddings[embedding.name] = embedding
|
||||
|
||||
@@ -276,15 +276,11 @@ class ClassicTextProcessingEngine:
|
||||
|
||||
pooled = getattr(z, 'pooled', None)
|
||||
|
||||
# Todo
|
||||
# e = emphasis.get_current_option(opts.emphasis)()
|
||||
|
||||
e = emphasis.EmphasisOriginal()
|
||||
e.tokens = remade_batch_tokens
|
||||
e.multipliers = torch.asarray(batch_multipliers)
|
||||
e.z = z
|
||||
e.after_transformers()
|
||||
z = e.z
|
||||
self.emphasis.tokens = remade_batch_tokens
|
||||
self.emphasis.multipliers = torch.asarray(batch_multipliers)
|
||||
self.emphasis.z = z
|
||||
self.emphasis.after_transformers()
|
||||
z = self.emphasis.z
|
||||
|
||||
if pooled is not None:
|
||||
z.pooled = pooled
|
||||
|
||||
Reference in New Issue
Block a user