add embedding layer impl

layerdiffusion
2024-08-04 13:10:08 -07:00
parent 21c8608373
commit cb9b155645
2 changed files with 45 additions and 21 deletions


@@ -1,9 +1,9 @@
 import math
-from collections import namedtuple
 import torch
+from collections import namedtuple
 from backend.text_processing import parsing, emphasis
+from textual_inversion import EmbeddingDatabase
 
 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
@@ -16,28 +16,48 @@ class PromptChunk:
         self.fixes = []
 
 
-class ClassicTextProcessingEngine(torch.nn.Module):
-    def __init__(self, wrapped, hijack):
-        super().__init__()
-        self.chunk_length = 75
-
-        self.is_trainable = False
-        self.input_key = 'txt'
-        self.return_pooled = False
-
-        self.comma_token = None
-
-        self.hijack = hijack
-        self.wrapped = wrapped
-
-        self.is_trainable = getattr(wrapped, 'is_trainable', False)
-        self.input_key = getattr(wrapped, 'input_key', 'txt')
-        self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
-
-        self.legacy_ucg_val = None  # for sgm codebase
-
-        self.tokenizer = wrapped.tokenizer
+class CLIPEmbeddingForTextualInversion(torch.nn.Module):
+    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
+        super().__init__()
+        self.wrapped = wrapped
+        self.embeddings = embeddings
+        self.textual_inversion_key = textual_inversion_key
+        self.weight = self.wrapped.weight
+
+    def forward(self, input_ids):
+        batch_fixes = self.embeddings.fixes
+        self.embeddings.fixes = None
+
+        inputs_embeds = self.wrapped(input_ids)
+
+        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+            return inputs_embeds
+
+        vecs = []
+        for fixes, tensor in zip(batch_fixes, inputs_embeds):
+            for offset, embedding in fixes:
+                emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
+            vecs.append(tensor)
+
+        return torch.stack(vecs)
+
+
+class ClassicTextProcessingEngine:
+    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768):
+        super().__init__()
+        self.chunk_length = chunk_length
+
+        self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
+        self.embeddings.add_embedding_dir(embedding_dir)
+        self.embeddings.load_textual_inversion_embeddings()
+
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+
+        model_embeddings = text_encoder.text_model.embeddings
+        model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
 
         vocab = self.tokenizer.get_vocab()
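
For reference, a minimal standalone sketch (not part of the commit) of the splice that the new forward() performs: the learned vectors overwrite the positions right after `offset`, clamped so the chunk length never changes. `seq` and `emb` are made-up stand-ins for one chunk's token embeddings and one loaded textual inversion vector.

    import torch

    seq = torch.zeros(5, 4)  # token embeddings for a 5-token chunk, width 4
    emb = torch.ones(2, 4)   # a textual inversion embedding with 2 vectors
    offset = 1               # vectors land at positions offset+1 .. offset+emb_len

    # clamp so the splice never runs past the end of the chunk
    emb_len = min(seq.shape[0] - offset - 1, emb.shape[0])
    seq = torch.cat([seq[0:offset + 1], emb[0:emb_len], seq[offset + 1 + emb_len:]])
    assert seq.shape == (5, 4)  # sequence length is preserved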
@@ -61,10 +81,13 @@ class ClassicTextProcessingEngine(torch.nn.Module):
             if mult != 1.0:
                 self.token_mults[ident] = mult
 
-        self.id_start = self.wrapped.tokenizer.bos_token_id
-        self.id_end = self.wrapped.tokenizer.eos_token_id
+        self.id_start = self.tokenizer.bos_token_id
+        self.id_end = self.tokenizer.eos_token_id
         self.id_pad = self.id_end
 
+        # Todo: remove these
+        self.legacy_ucg_val = None  # for sgm codebase
+
     def empty_chunk(self):
         chunk = PromptChunk()
         chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
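
For reference, a hedged usage sketch (not part of the commit) of wiring the engine to a Hugging Face CLIP text encoder. The `transformers` classes and checkpoint name are assumptions inferred from the `text_encoder.text_model.embeddings` access above; only the `ClassicTextProcessingEngine` signature comes from the diff.

    from transformers import CLIPTextModel, CLIPTokenizer

    # assumed checkpoint; any CLIP-L text model with 768-wide token embeddings fits
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

    engine = ClassicTextProcessingEngine(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        chunk_length=75,
        embedding_dir='./embeddings',    # scanned for textual inversion files
        embedding_key='clip_l',          # key used when embedding.vec is a dict
        embedding_expected_shape=768,    # expected per-vector width
    )

The constructor swaps text_encoder.text_model.embeddings.token_embedding for the CLIPEmbeddingForTextualInversion wrapper, so later forward passes through the text model pick up any fixes staged in engine.embeddings.fixes.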