diff --git a/backend/text_processing/engine.py b/backend/text_processing/engine.py
index 07cd2949..241d70cf 100644
--- a/backend/text_processing/engine.py
+++ b/backend/text_processing/engine.py
@@ -46,16 +46,28 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
 
 
 class ClassicTextProcessingEngine:
-    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original"):
+    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original", text_projection=None, minimal_clip_skip=1, clip_skip=1, return_pooled=False, callback_before_encode=None):
         super().__init__()
 
-        self.chunk_length = chunk_length
         self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
         self.embeddings.add_embedding_dir(embedding_dir)
         self.embeddings.load_textual_inversion_embeddings()
+
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
+
         self.emphasis = emphasis.get_current_option(emphasis_name)
+        self.text_projection = text_projection
+        self.minimal_clip_skip = minimal_clip_skip
+        self.clip_skip = clip_skip
+        self.return_pooled = return_pooled
+        self.callback_before_encode = callback_before_encode
+
+        self.chunk_length = chunk_length
+
+        self.id_start = self.tokenizer.bos_token_id
+        self.id_end = self.tokenizer.eos_token_id
+        self.id_pad = self.id_end
 
         model_embeddings = text_encoder.text_model.embeddings
         model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
@@ -82,13 +94,8 @@ class ClassicTextProcessingEngine:
             if mult != 1.0:
                 self.token_mults[ident] = mult
 
-        self.id_start = self.tokenizer.bos_token_id
-        self.id_end = self.tokenizer.eos_token_id
-        self.id_pad = self.id_end
-        self.return_pooled = True
-
-        # Todo: remove these
-        self.legacy_ucg_val = None  # for sgm codebase
+        # # Todo: remove these
+        # self.legacy_ucg_val = None  # for sgm codebase
 
     def empty_chunk(self):
         chunk = PromptChunk()
@@ -105,7 +112,20 @@ class ClassicTextProcessingEngine:
         return tokenized
 
     def encode_with_transformers(self, tokens):
-        raise NotImplementedError
+        self.text_encoder.transformer.text_model.embeddings.to(tokens.device)
+        outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)
+
+        layer_id = - max(self.clip_skip, self.minimal_clip_skip)
+        z = outputs.hidden_states[layer_id]
+
+        if self.return_pooled:
+            pooled_output = outputs.pooler_output
+
+            if self.text_projection is not None:
+                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+
+            z.pooled = pooled_output
+        return z
 
     def encode_embedding_init_text(self, init_text, nvpt):
         embedding_layer = self.text_encoder.transformer.text_model.embeddings
@@ -215,7 +235,10 @@ class ClassicTextProcessingEngine:
 
         return batch_chunks, token_count
 
-    def forward(self, texts):
+    def __call__(self, texts):
+        if self.callback_before_encode is not None:
+            self.callback_before_encode()
+
         batch_chunks, token_count = self.process_texts(texts)
 
         used_embeddings = {}
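
For reference, a minimal usage sketch of the refactored engine (not part of the patch). The TextEncoderWrapper class, the openai/clip-vit-large-patch14 checkpoint, and the no-op callback below are assumptions for illustration only; the diff itself only implies that text_encoder exposes .text_model (for the textual-inversion embedding hook) and .transformer (for encode_with_transformers), and that the tokenizer provides bos_token_id/eos_token_id.

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    class TextEncoderWrapper(torch.nn.Module):
        # Hypothetical adapter: expose the same HF CLIP text model under the two
        # attribute names the engine reads (.text_model and .transformer).
        def __init__(self, clip_text_model):
            super().__init__()
            self.transformer = clip_text_model
            self.text_model = clip_text_model.text_model

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = TextEncoderWrapper(CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14"))

    engine = ClassicTextProcessingEngine(
        text_encoder,
        tokenizer,
        embedding_dir='./embeddings',          # must exist; textual-inversion embeddings are loaded from here
        clip_skip=2,                           # applied via hidden_states[-max(clip_skip, minimal_clip_skip)]
        return_pooled=False,
        callback_before_encode=lambda: None,   # e.g. move the text encoder to the target device
    )

    cond = engine(["a photo of an astronaut riding a horse"])  # __call__ replaces the old forward()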