add transformer encode

layerdiffusion · 2024-08-04 13:45:32 -07:00
parent 8c118df739 · commit 1e23bf07ca


@@ -46,16 +46,28 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
 class ClassicTextProcessingEngine:
-    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original"):
+    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original", text_projection=None, minimal_clip_skip=1, clip_skip=1, return_pooled=False, callback_before_encode=None):
         super().__init__()
 
-        self.chunk_length = chunk_length
         self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
         self.embeddings.add_embedding_dir(embedding_dir)
         self.embeddings.load_textual_inversion_embeddings()
 
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
 
         self.emphasis = emphasis.get_current_option(emphasis_name)
+        self.text_projection = text_projection
+        self.minimal_clip_skip = minimal_clip_skip
+        self.clip_skip = clip_skip
+        self.return_pooled = return_pooled
+        self.callback_before_encode = callback_before_encode
+
+        self.chunk_length = chunk_length
+
+        self.id_start = self.tokenizer.bos_token_id
+        self.id_end = self.tokenizer.eos_token_id
+        self.id_pad = self.id_end
 
         model_embeddings = text_encoder.text_model.embeddings
         model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
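For orientation, here is a minimal sketch of how the widened constructor might be driven from a Hugging Face CLIP text model. Everything outside this diff is an assumption for illustration: the checkpoint name, the TextEncoderShim wrapper (inferred from the two attribute paths the class actually reads), and the idea that ClassicTextProcessingEngine is importable from this module.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

class TextEncoderShim(torch.nn.Module):
    # Assumption: __init__ reads text_encoder.text_model.embeddings while
    # encode_with_transformers calls self.text_encoder.transformer(...), so
    # the same CLIPTextModel is exposed under both names.
    def __init__(self, clip_text_model):
        super().__init__()
        self.transformer = clip_text_model
        self.text_model = clip_text_model.text_model

clip = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

engine = ClassicTextProcessingEngine(
    text_encoder=TextEncoderShim(clip),
    tokenizer=tokenizer,
    embedding_key='clip_l',
    embedding_expected_shape=768,
    clip_skip=2,          # read the penultimate hidden state
    return_pooled=False,  # CLIP-L conditioning typically skips the pooled vector
)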
@@ -82,13 +94,8 @@ class ClassicTextProcessingEngine:
             if mult != 1.0:
                 self.token_mults[ident] = mult
 
-        self.id_start = self.tokenizer.bos_token_id
-        self.id_end = self.tokenizer.eos_token_id
-        self.id_pad = self.id_end
-        self.return_pooled = True
-
-        # Todo: remove these
-        self.legacy_ucg_val = None  # for sgm codebase
+        # # Todo: remove these
+        # self.legacy_ucg_val = None  # for sgm codebase
 
     def empty_chunk(self):
         chunk = PromptChunk()
@@ -105,7 +112,20 @@ class ClassicTextProcessingEngine:
         return tokenized
 
     def encode_with_transformers(self, tokens):
-        raise NotImplementedError
+        self.text_encoder.transformer.text_model.embeddings.to(tokens.device)
+
+        outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)
+
+        layer_id = -max(self.clip_skip, self.minimal_clip_skip)
+        z = outputs.hidden_states[layer_id]
+
+        if self.return_pooled:
+            pooled_output = outputs.pooler_output
+
+            if self.text_projection is not None:
+                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+
+            z.pooled = pooled_output
+
+        return z
 
     def encode_embedding_init_text(self, init_text, nvpt):
         embedding_layer = self.text_encoder.transformer.text_model.embeddings
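The layer selection is compact enough to misread: clip_skip counts hidden states backwards from the output, and minimal_clip_skip acts as a floor. A standalone sketch with toy tensors follows; only the two formulas are taken from the commit, every shape here is assumed.

import torch

# 13 hidden states for a 12-layer CLIP text model: index 0 is the embedding
# output, index -1 the final layer (as returned with output_hidden_states=True).
hidden_states = [torch.randn(1, 77, 768) for _ in range(13)]

def pick_layer(clip_skip, minimal_clip_skip=1):
    # clip_skip=1 -> hidden_states[-1] (last layer);
    # clip_skip=2 -> hidden_states[-2] (penultimate, the common "clip skip 2").
    return hidden_states[-max(clip_skip, minimal_clip_skip)]

assert pick_layer(1) is hidden_states[-1]
assert pick_layer(2) is hidden_states[-2]
assert pick_layer(0) is hidden_states[-1]  # the floor keeps clip_skip=0 valid

# Pooled path: an SDXL-style text_projection is a plain matmul over the pooled
# vector; (batch, 768) @ (768, 768) keeps the width here, shapes assumed.
pooled = torch.randn(1, 768)
text_projection = torch.randn(768, 768)
projected = pooled.float() @ text_projection.float()

# PyTorch tensors accept ad-hoc attributes, which is what z.pooled relies on.
z = hidden_states[-1]
z.pooled = projected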
@@ -215,7 +235,10 @@ class ClassicTextProcessingEngine:
         return batch_chunks, token_count
 
-    def forward(self, texts):
+    def __call__(self, texts):
+        if self.callback_before_encode is not None:
+            self.callback_before_encode()
+
         batch_chunks, token_count = self.process_texts(texts)
 
         used_embeddings = {}
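Because the engine is now invoked directly instead of through nn.Module.forward, a hook can run immediately before every encode. Continuing the hypothetical objects from the first sketch, one plausible use is lazy device placement:

# Hypothetical callback: move the text encoder to the GPU just before a
# prompt is encoded, rather than at construction time.
def move_encoder_to_gpu():
    clip.to("cuda")

engine.callback_before_encode = move_encoder_to_gpu

# __call__ rather than forward; the conditioning is assembled from the
# processed chunks in the remainder of the method (not shown in this hunk).
cond = engine(["a photo of an astronaut riding a horse"])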