mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-26 09:43:56 +00:00
add embedding layer impl
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import math
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from collections import namedtuple
|
||||
from backend.text_processing import parsing, emphasis
|
||||
from textual_inversion import EmbeddingDatabase
|
||||
|
||||
|
||||
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||
@@ -16,28 +16,48 @@ class PromptChunk:
|
||||
self.fixes = []
|
||||
|
||||
|
||||
class ClassicTextProcessingEngine(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
class CLIPEmbeddingForTextualInversion(torch.nn.Module):
|
||||
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||
super().__init__()
|
||||
self.chunk_length = 75
|
||||
|
||||
self.is_trainable = False
|
||||
self.input_key = 'txt'
|
||||
self.return_pooled = False
|
||||
|
||||
self.comma_token = None
|
||||
|
||||
self.hijack = hijack
|
||||
|
||||
self.wrapped = wrapped
|
||||
self.embeddings = embeddings
|
||||
self.textual_inversion_key = textual_inversion_key
|
||||
self.weight = self.wrapped.weight
|
||||
|
||||
self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
||||
self.input_key = getattr(wrapped, 'input_key', 'txt')
|
||||
self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
|
||||
def forward(self, input_ids):
|
||||
batch_fixes = self.embeddings.fixes
|
||||
self.embeddings.fixes = None
|
||||
|
||||
self.legacy_ucg_val = None # for sgm codebase
|
||||
inputs_embeds = self.wrapped(input_ids)
|
||||
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
vecs = []
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
||||
|
||||
vecs.append(tensor)
|
||||
|
||||
return torch.stack(vecs)
|
||||
|
||||
|
||||
class ClassicTextProcessingEngine:
|
||||
def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768):
|
||||
super().__init__()
|
||||
|
||||
self.chunk_length = chunk_length
|
||||
self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
|
||||
self.embeddings.add_embedding_dir(embedding_dir)
|
||||
self.embeddings.load_textual_inversion_embeddings()
|
||||
self.text_encoder = text_encoder
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
model_embeddings = text_encoder.text_model.embeddings
|
||||
model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
|
||||
@@ -61,10 +81,13 @@ class ClassicTextProcessingEngine(torch.nn.Module):
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
self.id_start = self.wrapped.tokenizer.bos_token_id
|
||||
self.id_end = self.wrapped.tokenizer.eos_token_id
|
||||
self.id_start = self.tokenizer.bos_token_id
|
||||
self.id_end = self.tokenizer.eos_token_id
|
||||
self.id_pad = self.id_end
|
||||
|
||||
# Todo: remove these
|
||||
self.legacy_ucg_val = None # for sgm codebase
|
||||
|
||||
def empty_chunk(self):
|
||||
chunk = PromptChunk()
|
||||
chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
|
||||
|
||||
Reference in New Issue
Block a user