From cb9b1556454f311e31a58c7df03462c42ba876c4 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Sun, 4 Aug 2024 13:10:08 -0700
Subject: [PATCH] add embedding layer impl

---
 backend/text_processing/engine.py            | 65 +++++++++++++-------
 backend/text_processing/textual_inversion.py |  1 +
 2 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/backend/text_processing/engine.py b/backend/text_processing/engine.py
index d7a87963..b19b30da 100644
--- a/backend/text_processing/engine.py
+++ b/backend/text_processing/engine.py
@@ -1,9 +1,9 @@
 import math
-from collections import namedtuple
-
 import torch
+from collections import namedtuple

 from backend.text_processing import parsing, emphasis
+from backend.text_processing.textual_inversion import EmbeddingDatabase


 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
@@ -16,28 +16,48 @@ class PromptChunk:
         self.fixes = []


-class ClassicTextProcessingEngine(torch.nn.Module):
-    def __init__(self, wrapped, hijack):
+class CLIPEmbeddingForTextualInversion(torch.nn.Module):
+    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
         super().__init__()
-        self.chunk_length = 75
-
-        self.is_trainable = False
-        self.input_key = 'txt'
-        self.return_pooled = False
-
-        self.comma_token = None
-
-        self.hijack = hijack
-
         self.wrapped = wrapped
+        self.embeddings = embeddings
+        self.textual_inversion_key = textual_inversion_key
+        self.weight = self.wrapped.weight

-        self.is_trainable = getattr(wrapped, 'is_trainable', False)
-        self.input_key = getattr(wrapped, 'input_key', 'txt')
-        self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
+    def forward(self, input_ids):
+        batch_fixes = self.embeddings.fixes
+        self.embeddings.fixes = None

-        self.legacy_ucg_val = None  # for sgm codebase
+        inputs_embeds = self.wrapped(input_ids)

-        self.tokenizer = wrapped.tokenizer
+        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+            return inputs_embeds
+
+        vecs = []
+        for fixes, tensor in zip(batch_fixes, inputs_embeds):
+            for offset, embedding in fixes:
+                emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
+
+            vecs.append(tensor)
+
+        return torch.stack(vecs)
+
+
+class ClassicTextProcessingEngine:
+    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768):
+        super().__init__()
+
+        self.chunk_length = chunk_length
+        self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
+        self.embeddings.add_embedding_dir(embedding_dir)
+        self.embeddings.load_textual_inversion_embeddings()
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+
+        model_embeddings = text_encoder.text_model.embeddings
+        model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)


         vocab = self.tokenizer.get_vocab()
@@ -61,10 +81,13 @@ class ClassicTextProcessingEngine(torch.nn.Module):
             if mult != 1.0:
                 self.token_mults[ident] = mult

-        self.id_start = self.wrapped.tokenizer.bos_token_id
-        self.id_end = self.wrapped.tokenizer.eos_token_id
+        self.id_start = self.tokenizer.bos_token_id
+        self.id_end = self.tokenizer.eos_token_id
         self.id_pad = self.id_end

+        # Todo: remove these
+        self.legacy_ucg_val = None  # for sgm codebase
+
     def empty_chunk(self):
         chunk = PromptChunk()
         chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
diff --git a/backend/text_processing/textual_inversion.py b/backend/text_processing/textual_inversion.py
index d9d6c891..cd4488b4 100644
--- a/backend/text_processing/textual_inversion.py
+++ b/backend/text_processing/textual_inversion.py
@@ -116,6 +116,7 @@ class EmbeddingDatabase:
         self.skipped_embeddings = {}
         self.expected_shape = expected_shape
         self.tokenizer = tokenizer
+        self.fixes = []

     def add_embedding_dir(self, path):
         self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
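
Usage sketch (not part of the patch): a minimal illustration of how the new ClassicTextProcessingEngine could be wired to a CLIP text encoder. It assumes a Hugging Face transformers CLIPTokenizer/CLIPTextModel pair; the checkpoint name, the './embeddings' directory, and the variable names are illustrative only, and the actual prompt-encoding path is outside the hunks shown above.

# Hypothetical wiring example; only the constructor added in this patch is exercised.
from transformers import CLIPTextModel, CLIPTokenizer

from backend.text_processing.engine import ClassicTextProcessingEngine

# Assumed checkpoint: a CLIP-L text tower with 768-dim token embeddings.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')

# Per the diff, the constructor builds an EmbeddingDatabase, scans the embedding
# directory, and replaces text_encoder.text_model.embeddings.token_embedding with
# CLIPEmbeddingForTextualInversion. That wrapper's forward() then splices the loaded
# embedding vectors into the token-embedding output at the offsets recorded in the
# database's `fixes` list.
engine = ClassicTextProcessingEngine(
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    chunk_length=75,
    embedding_dir='./embeddings',
    embedding_key='clip_l',
    embedding_expected_shape=768,
)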