add embedding layer impl
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
@@ -1,9 +1,9 @@
 import math
-from collections import namedtuple

 import torch

+from collections import namedtuple
 from backend.text_processing import parsing, emphasis
+from textual_inversion import EmbeddingDatabase

-
 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
@@ -16,28 +16,48 @@ class PromptChunk:
         self.fixes = []


-class ClassicTextProcessingEngine(torch.nn.Module):
-    def __init__(self, wrapped, hijack):
+class CLIPEmbeddingForTextualInversion(torch.nn.Module):
+    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
         super().__init__()
-        self.chunk_length = 75
-
-        self.is_trainable = False
-        self.input_key = 'txt'
-        self.return_pooled = False
-
-        self.comma_token = None
-
-        self.hijack = hijack
-
         self.wrapped = wrapped
+        self.embeddings = embeddings
+        self.textual_inversion_key = textual_inversion_key
+        self.weight = self.wrapped.weight

-        self.is_trainable = getattr(wrapped, 'is_trainable', False)
-        self.input_key = getattr(wrapped, 'input_key', 'txt')
-        self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
+    def forward(self, input_ids):
+        batch_fixes = self.embeddings.fixes
+        self.embeddings.fixes = None

-        self.legacy_ucg_val = None  # for sgm codebase
+        inputs_embeds = self.wrapped(input_ids)

-        self.tokenizer = wrapped.tokenizer
+        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+            return inputs_embeds
+
+        vecs = []
+        for fixes, tensor in zip(batch_fixes, inputs_embeds):
+            for offset, embedding in fixes:
+                emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
+
+            vecs.append(tensor)
+
+        return torch.stack(vecs)
+
+
+class ClassicTextProcessingEngine:
+    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768):
+        super().__init__()
+
+        self.chunk_length = chunk_length
+
+        self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
+        self.embeddings.add_embedding_dir(embedding_dir)
+        self.embeddings.load_textual_inversion_embeddings()
+
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+
+        model_embeddings = text_encoder.text_model.embeddings
+        model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)

         vocab = self.tokenizer.get_vocab()

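Note on the hunk above: CLIPEmbeddingForTextualInversion wraps the CLIP token-embedding layer so that textual-inversion vectors can be spliced into the embedded prompt at the offsets recorded in embeddings.fixes. Below is a minimal sketch of that splice in isolation, using a toy embedding table and stand-in objects (a SimpleNamespace with a .vec tensor instead of the project's real embedding class, and a hand-written fixes list); it is an illustration of the torch.cat logic in forward, not the project's code path.

import torch
from types import SimpleNamespace

# Toy stand-ins: a 16-token vocabulary embedded in 4 dims, and one
# "textual inversion" embedding of 2 vectors that should replace the
# tokens starting right after position `offset`.
token_embedding = torch.nn.Embedding(16, 4)
ti_vec = torch.ones(2, 4)                       # hypothetical learned vectors
embedding = SimpleNamespace(vec=ti_vec)         # mimics the .vec attribute used in forward

input_ids = torch.tensor([[0, 5, 6, 7, 8, 1]])  # one prompt of 6 token ids
inputs_embeds = token_embedding(input_ids)

# One list of fixes per prompt, same shape as PromptChunkFix(offset, embedding).
batch_fixes = [[(1, embedding)]]

vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
    for offset, emb_obj in fixes:
        emb = emb_obj.vec
        emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
        # Keep everything up to offset, drop emb_len looked-up rows,
        # and put the textual-inversion vectors in their place.
        tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
    vecs.append(tensor)

result = torch.stack(vecs)
print(result.shape)  # torch.Size([1, 6, 4]); positions 2 and 3 now hold ti_vec

With offset 1 and a 2-vector embedding, positions 2 and 3 of the output are the textual-inversion vectors while every other position is the ordinary lookup, which is what forward does chunk by chunk.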
@@ -61,10 +81,13 @@ class ClassicTextProcessingEngine(torch.nn.Module):
             if mult != 1.0:
                 self.token_mults[ident] = mult

-        self.id_start = self.wrapped.tokenizer.bos_token_id
-        self.id_end = self.wrapped.tokenizer.eos_token_id
+        self.id_start = self.tokenizer.bos_token_id
+        self.id_end = self.tokenizer.eos_token_id
         self.id_pad = self.id_end

+        # Todo: remove these
+        self.legacy_ucg_val = None  # for sgm codebase
+
     def empty_chunk(self):
         chunk = PromptChunk()
         chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
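For a concrete sense of id_start/id_end/id_pad: assuming the standard CLIP tokenizer ids (bos 49406, eos 49407 — an assumption, the diff only references bos_token_id/eos_token_id) and the default chunk_length of 75, the token list built by empty_chunk() is 77 ids long, a bos followed by 76 eos/pad:

# Assumed CLIP tokenizer ids; the diff never names these concrete values.
chunk_length = 75
id_start, id_end = 49406, 49407
id_pad = id_end  # padding reuses the end-of-text id, as set above

tokens = [id_start] + [id_end] * (chunk_length + 1)
print(len(tokens))                        # 77 positions per chunk
print(tokens[0], tokens[1], tokens[-1])   # 49406 49407 49407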
@@ -116,6 +116,7 @@ class EmbeddingDatabase:
         self.skipped_embeddings = {}
         self.expected_shape = expected_shape
         self.tokenizer = tokenizer
+        self.fixes = []

     def add_embedding_dir(self, path):
         self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
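The one-line addition of self.fixes = [] to EmbeddingDatabase is the handoff point between prompt processing and the embedding layer: the engine is expected to write one list of PromptChunkFix(offset, embedding) per prompt into embeddings.fixes, and CLIPEmbeddingForTextualInversion.forward then consumes and clears it. The sketch below shows only that order of operations; the class and objects are hypothetical stand-ins, not how the project constructs them.

from collections import namedtuple

PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])

class FakeEmbeddingDatabase:
    # Stand-in for EmbeddingDatabase; only the field added in this commit matters here.
    def __init__(self):
        self.fixes = []

db = FakeEmbeddingDatabase()

# 1. While building prompt chunks, the engine records where each
#    textual-inversion embedding must be spliced in (one list per prompt).
some_embedding = object()  # stand-in for a loaded embedding with a .vec tensor
db.fixes = [[PromptChunkFix(offset=3, embedding=some_embedding)]]

# 2. On the next call, the wrapped token-embedding layer takes the lists
#    and clears the field, so stale fixes cannot leak into a later batch.
batch_fixes, db.fixes = db.fixes, None
assert batch_fixes[0][0].offset == 3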