From 8c118df739c4744f27db931c72f6a5a277631bbd Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Sun, 4 Aug 2024 13:19:32 -0700
Subject: [PATCH] emphasis

---
 backend/text_processing/engine.py | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)

diff --git a/backend/text_processing/engine.py b/backend/text_processing/engine.py
index b19b30da..07cd2949 100644
--- a/backend/text_processing/engine.py
+++ b/backend/text_processing/engine.py
@@ -46,7 +46,7 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
 
 
 class ClassicTextProcessingEngine:
-    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768):
+    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original"):
         super().__init__()
 
         self.chunk_length = chunk_length
@@ -55,6 +55,7 @@ class ClassicTextProcessingEngine:
         self.embeddings.load_textual_inversion_embeddings()
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
+        self.emphasis = emphasis.get_current_option(emphasis_name)
 
         model_embeddings = text_encoder.text_model.embeddings
         model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
@@ -84,6 +85,7 @@ class ClassicTextProcessingEngine:
         self.id_start = self.tokenizer.bos_token_id
         self.id_end = self.tokenizer.eos_token_id
         self.id_pad = self.id_end
+        self.return_pooled = True
 
         # Todo: remove these
         self.legacy_ucg_val = None  # for sgm codebase
@@ -98,7 +100,7 @@ class ClassicTextProcessingEngine:
         return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
 
     def tokenize(self, texts):
-        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+        tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
 
         return tokenized
 
@@ -106,8 +108,8 @@ class ClassicTextProcessingEngine:
         raise NotImplementedError
 
     def encode_embedding_init_text(self, init_text, nvpt):
-        embedding_layer = self.wrapped.transformer.text_model.embeddings
-        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+        embedding_layer = self.text_encoder.transformer.text_model.embeddings
+        ids = self.text_encoder.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
         embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
 
         return embedded
@@ -122,8 +124,6 @@ class ClassicTextProcessingEngine:
         last_comma = -1
 
         def next_chunk(is_last=False):
-            """puts current chunk into the list of results and produces the next one - empty;
-            if is_last is true, tokens tokens at the end won't add to token_count"""
             nonlocal token_count
             nonlocal last_comma
             nonlocal chunk
@@ -175,7 +175,7 @@ class ClassicTextProcessingEngine:
             if len(chunk.tokens) == self.chunk_length:
                 next_chunk()
 
-            embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
+            embedding, embedding_length_in_tokens = self.embeddings.find_embedding_at_position(tokens, position)
             if embedding is None:
                 chunk.tokens.append(token)
                 chunk.multipliers.append(weight)
@@ -227,9 +227,9 @@ class ClassicTextProcessingEngine:
 
             tokens = [x.tokens for x in batch_chunk]
             multipliers = [x.multipliers for x in batch_chunk]
-            self.hijack.fixes = [x.fixes for x in batch_chunk]
+            self.embeddings.fixes = [x.fixes for x in batch_chunk]
 
-            for fixes in self.hijack.fixes:
+            for fixes in self.embeddings.fixes:
                 for _position, embedding in fixes:
                     used_embeddings[embedding.name] = embedding
 
@@ -276,15 +276,11 @@ class ClassicTextProcessingEngine:
 
         pooled = getattr(z, 'pooled', None)
 
-        # Todo
-        # e = emphasis.get_current_option(opts.emphasis)()
-
-        e = emphasis.EmphasisOriginal()
-        e.tokens = remade_batch_tokens
-        e.multipliers = torch.asarray(batch_multipliers)
-        e.z = z
-        e.after_transformers()
-        z = e.z
+        self.emphasis.tokens = remade_batch_tokens
+        self.emphasis.multipliers = torch.asarray(batch_multipliers)
+        self.emphasis.z = z
+        self.emphasis.after_transformers()
+        z = self.emphasis.z
 
        if pooled is not None:
            z.pooled = pooled
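
Note (not part of the patch): after this change the engine only talks to the configured emphasis object through the small surface used in the last hunk above, i.e. the tokens, multipliers, and z attributes plus the after_transformers() hook. A minimal sketch of such an object follows, assuming the classic mean-preserving weighting; the class name and the rescaling details are illustrative assumptions, not the backend's actual emphasis module.

    import torch

    class EmphasisSketch:
        # Illustrative stand-in for an option returned by
        # emphasis.get_current_option(name); attribute names mirror
        # the ones the engine sets in the hunk above.
        tokens = None       # remade batch tokens, list of lists of token ids
        multipliers = None  # tensor of per-token prompt weights, shape (batch, tokens)
        z = None            # transformer hidden states, shape (batch, tokens, dim)

        def after_transformers(self):
            # Scale each token's hidden state by its prompt weight, then
            # restore the original mean so the overall magnitude of z stays
            # comparable (assumed here; mirrors the classic "original"
            # emphasis behaviour).
            original_mean = self.z.mean()
            self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
            self.z = self.z * (original_mean / self.z.mean())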