Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
add transformer encode
@@ -46,16 +46,28 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
 
 
 class ClassicTextProcessingEngine:
-    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original"):
+    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original", text_projection=None, minimal_clip_skip=1, clip_skip=1, return_pooled=False, callback_before_encode=None):
         super().__init__()
 
-        self.chunk_length = chunk_length
         self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
         self.embeddings.add_embedding_dir(embedding_dir)
         self.embeddings.load_textual_inversion_embeddings()
 
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
 
+        self.emphasis = emphasis.get_current_option(emphasis_name)
+        self.text_projection = text_projection
+        self.minimal_clip_skip = minimal_clip_skip
+        self.clip_skip = clip_skip
+        self.return_pooled = return_pooled
+        self.callback_before_encode = callback_before_encode
+
+        self.chunk_length = chunk_length
+
+        self.id_start = self.tokenizer.bos_token_id
+        self.id_end = self.tokenizer.eos_token_id
+        self.id_pad = self.id_end
+
         model_embeddings = text_encoder.text_model.embeddings
         model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
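The hunk above widens the constructor so clip skip, pooled output, the text projection and a pre-encode hook are configurable per engine instance instead of being hard-coded. A minimal instantiation sketch follows; the WrappedTextEncoder shim, the checkpoint name and the argument values are illustrative assumptions and not part of this commit (in Forge the text encoder comes from the loaded checkpoint).

    # Hedged sketch: driving the widened constructor. Not Forge code.
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    class WrappedTextEncoder(torch.nn.Module):
        # hypothetical shim so both `text_encoder.text_model` (used in __init__) and
        # `text_encoder.transformer` (used in encode_with_transformers) resolve
        def __init__(self, clip_text_model):
            super().__init__()
            self.transformer = clip_text_model
            self.text_model = clip_text_model.text_model

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = WrappedTextEncoder(CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14"))

    engine = ClassicTextProcessingEngine(
        text_encoder, tokenizer,
        clip_skip=2,                  # take the penultimate hidden state
        return_pooled=False,          # SD1-style conditioning needs no pooled vector
        callback_before_encode=None,  # optional hook, see the __call__ hunk below
    )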
@@ -82,13 +94,8 @@ class ClassicTextProcessingEngine:
             if mult != 1.0:
                 self.token_mults[ident] = mult
 
-        self.id_start = self.tokenizer.bos_token_id
-        self.id_end = self.tokenizer.eos_token_id
-        self.id_pad = self.id_end
-        self.return_pooled = True
-
-        # Todo: remove these
-        self.legacy_ucg_val = None  # for sgm codebase
+        # # Todo: remove these
+        # self.legacy_ucg_val = None  # for sgm codebase
 
     def empty_chunk(self):
         chunk = PromptChunk()
@@ -105,7 +112,20 @@ class ClassicTextProcessingEngine:
         return tokenized
 
     def encode_with_transformers(self, tokens):
-        raise NotImplementedError
+        self.text_encoder.transformer.text_model.embeddings.to(tokens.device)
+        outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)
+
+        layer_id = - max(self.clip_skip, self.minimal_clip_skip)
+        z = outputs.hidden_states[layer_id]
+
+        if self.return_pooled:
+            pooled_output = outputs.pooler_output
+
+            if self.text_projection:
+                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+
+            z.pooled = pooled_output
+        return z
 
     def encode_embedding_init_text(self, init_text, nvpt):
         embedding_layer = self.text_encoder.transformer.text_model.embeddings
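A small self-contained check of the layer-selection arithmetic introduced above: with output_hidden_states=True a CLIP text model returns the embedding output plus one hidden state per transformer layer, and layer_id = -max(clip_skip, minimal_clip_skip) counts back from the end, so clip_skip=1 keeps the last hidden state and clip_skip=2 the penultimate one. The toy tensors below stand in for outputs.hidden_states and for the projection matrix; none of this is Forge code.

    import torch

    # stand-in for outputs.hidden_states: embedding output + 12 transformer layers
    hidden_states = [torch.full((1, 77, 768), float(i)) for i in range(13)]

    minimal_clip_skip = 1
    for clip_skip in (1, 2):
        layer_id = -max(clip_skip, minimal_clip_skip)
        z = hidden_states[layer_id]
        print(clip_skip, int(z[0, 0, 0]))   # prints "1 12" (last) and "2 11" (penultimate)

    # the optional text projection is a plain matrix product applied to the pooled vector
    pooled_output = torch.randn(1, 768)
    text_projection = torch.randn(768, 768)
    projected = pooled_output.float() @ text_projection.float()   # shape (1, 768)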
@@ -215,7 +235,10 @@ class ClassicTextProcessingEngine:
 
         return batch_chunks, token_count
 
-    def forward(self, texts):
+    def __call__(self, texts):
+        if self.callback_before_encode is not None:
+            self.callback_before_encode()
+
         batch_chunks, token_count = self.process_texts(texts)
 
         used_embeddings = {}
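With forward renamed to __call__, the engine is invoked like a function, and callback_before_encode gives callers a hook that fires before any tokenization or encoding happens. A hedged usage sketch follows, reusing text_encoder and tokenizer from the earlier sketch; the device-moving hook is only one plausible use, and Forge's own memory management is not shown in this diff.

    import torch

    def ensure_text_encoder_ready():
        # hypothetical hook: make sure the encoder sits on the execution device
        text_encoder.to("cuda" if torch.cuda.is_available() else "cpu")

    engine = ClassicTextProcessingEngine(
        text_encoder, tokenizer,
        return_pooled=False,
        callback_before_encode=ensure_text_encoder_ready,
    )

    # __call__ replaces engine.forward(...); the rest of its body is outside this diff
    cond = engine(["a photo of an astronaut riding a horse"])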