commit 8c118df739 (parent cb9b155645)
Author: layerdiffusion
Date:   2024-08-04 13:19:32 -07:00


@@ -46,7 +46,7 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
 class ClassicTextProcessingEngine:
-    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768):
+    def __init__(self, text_encoder, tokenizer, chunk_length=75, embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original"):
         super().__init__()
         self.chunk_length = chunk_length
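
Note: the constructor gains an emphasis_name parameter, so the emphasis mode is
chosen per engine instead of being hard-coded. A minimal construction sketch;
the Hugging Face model names here are illustrative assumptions, not part of
this commit:

    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # emphasis_name selects a handler registered in the emphasis module;
    # "original" reproduces the previously hard-coded EmphasisOriginal behavior.
    engine = ClassicTextProcessingEngine(text_encoder, tokenizer, emphasis_name="original")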
@@ -55,6 +55,7 @@ class ClassicTextProcessingEngine:
         self.embeddings.load_textual_inversion_embeddings()
         self.text_encoder = text_encoder
         self.tokenizer = tokenizer
+        self.emphasis = emphasis.get_current_option(emphasis_name)()
         model_embeddings = text_encoder.text_model.embeddings
         model_embeddings.token_embedding = CLIPEmbeddingForTextualInversion(model_embeddings.token_embedding, self.embeddings, textual_inversion_key=embedding_key)
@@ -84,6 +85,7 @@ class ClassicTextProcessingEngine:
         self.id_start = self.tokenizer.bos_token_id
         self.id_end = self.tokenizer.eos_token_id
         self.id_pad = self.id_end
+        self.return_pooled = True
         # Todo: remove these
         self.legacy_ucg_val = None  # for sgm codebase
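
Note: self.return_pooled = True marks this engine as returning the pooled CLIP
vector alongside the per-token hidden states, matching the z.pooled handling at
the end of this diff. A sketch of the consumer side, assuming the engine is
callable on a list of prompts:

    z = engine(["a photo of a cat"])     # [batch, tokens, hidden] conditioning
    pooled = getattr(z, "pooled", None)  # [batch, hidden] pooled embedding, if set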
@@ -98,7 +100,7 @@ class ClassicTextProcessingEngine:
         return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

     def tokenize(self, texts):
-        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+        tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
         return tokenized
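
Note: tokenize now uses the tokenizer stored on the engine instead of reaching
through the removed self.wrapped shim; behavior is unchanged. For the rounding
on the context line above: an 83-token prompt targets math.ceil(83 / 75) * 75 =
150 tokens, i.e. two full 75-token chunks. A standalone sketch with a stock
CLIP tokenizer (the model name is an assumption):

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    # truncation=False keeps prompts longer than the model limit intact (they
    # are chunked later); add_special_tokens=False leaves BOS/EOS to be added
    # per 75-token chunk by the engine itself.
    ids = tokenizer(["a photo of a cat"], truncation=False, add_special_tokens=False)["input_ids"]
    # -> one list of token ids per input prompt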
@@ -106,8 +108,8 @@ class ClassicTextProcessingEngine:
         raise NotImplementedError

     def encode_embedding_init_text(self, init_text, nvpt):
-        embedding_layer = self.wrapped.transformer.text_model.embeddings
-        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+        embedding_layer = self.text_encoder.transformer.text_model.embeddings
+        ids = self.text_encoder.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
         embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
         return embedded
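
Note: encode_embedding_init_text now resolves everything through
self.text_encoder rather than self.wrapped. The method seeds a new
textual-inversion embedding by pushing its init text through the frozen
token-embedding table. A hypothetical standalone version of the same idea:

    import torch

    def init_vectors_from_text(tokenizer, token_embedding, init_text, num_vectors):
        # Tokenize the init text, capped at num_vectors tokens.
        ids = tokenizer(init_text, max_length=num_vectors, return_tensors="pt",
                        add_special_tokens=False)["input_ids"]
        # Look the ids up in the frozen embedding table to get starting vectors.
        with torch.no_grad():
            vectors = token_embedding(ids.to(token_embedding.weight.device)).squeeze(0)
        return vectors  # [num_tokens, dim], e.g. dim 768 for clip_l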
@@ -122,8 +124,6 @@ class ClassicTextProcessingEngine:
         last_comma = -1

         def next_chunk(is_last=False):
-            """puts current chunk into the list of results and produces the next one - empty;
-            if is_last is true, tokens <end-of-text> tokens at the end won't add to token_count"""
             nonlocal token_count
             nonlocal last_comma
             nonlocal chunk
@@ -175,7 +175,7 @@ class ClassicTextProcessingEngine:
             if len(chunk.tokens) == self.chunk_length:
                 next_chunk()

-            embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
+            embedding, embedding_length_in_tokens = self.embeddings.find_embedding_at_position(tokens, position)

             if embedding is None:
                 chunk.tokens.append(token)
                 chunk.multipliers.append(weight)
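
Note: the lookup moves from the global self.hijack.embedding_db to the engine's
own self.embeddings; the contract is the same: given the token stream and an
offset, return the textual-inversion embedding whose trigger tokens start there
(or None), plus its length in tokens. A minimal sketch of that contract; the
data layout is an assumption, not this repo's exact class:

    class EmbeddingDatabaseSketch:
        def __init__(self):
            # first trigger token id -> list of (trigger_token_ids, embedding)
            self.ids_lookup = {}

        def find_embedding_at_position(self, tokens, offset):
            candidates = self.ids_lookup.get(tokens[offset], [])
            # Prefer the longest trigger sequence matching at this offset.
            for ids, embedding in sorted(candidates, key=lambda c: len(c[0]), reverse=True):
                if tokens[offset:offset + len(ids)] == ids:
                    return embedding, len(ids)
            return None, None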
@@ -227,9 +227,9 @@ class ClassicTextProcessingEngine:
             tokens = [x.tokens for x in batch_chunk]
             multipliers = [x.multipliers for x in batch_chunk]
-            self.hijack.fixes = [x.fixes for x in batch_chunk]
+            self.embeddings.fixes = [x.fixes for x in batch_chunk]

-            for fixes in self.hijack.fixes:
+            for fixes in self.embeddings.fixes:
                 for _position, embedding in fixes:
                     used_embeddings[embedding.name] = embedding
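
Note: the per-sample "fixes" (position, embedding) pairs are likewise stashed on
self.embeddings, where the CLIPEmbeddingForTextualInversion wrapper installed in
__init__ can consume them during the forward pass. A simplified sketch of that
splice; attribute names such as embedding.vec are assumptions:

    import torch

    def apply_fixes(token_embeddings, fixes_per_sample):
        # token_embeddings: [batch, seq_len, dim] from the wrapped embedding layer.
        for i, fixes in enumerate(fixes_per_sample):
            for position, embedding in fixes:
                vec = embedding.vec.to(token_embeddings)  # match device/dtype
                # Overwrite the placeholder token positions with learned vectors.
                token_embeddings[i, position:position + vec.shape[0]] = vec
        return token_embeddings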
@@ -276,15 +276,11 @@ class ClassicTextProcessingEngine:
         pooled = getattr(z, 'pooled', None)

-        # Todo
-        # e = emphasis.get_current_option(opts.emphasis)()
-        e = emphasis.EmphasisOriginal()
-        e.tokens = remade_batch_tokens
-        e.multipliers = torch.asarray(batch_multipliers)
-        e.z = z
-        e.after_transformers()
-        z = e.z
+        self.emphasis.tokens = remade_batch_tokens
+        self.emphasis.multipliers = torch.asarray(batch_multipliers)
+        self.emphasis.z = z
+        self.emphasis.after_transformers()
+        z = self.emphasis.z

         if pooled is not None:
             z.pooled = pooled
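
Note: the emphasis pass now reuses the handler chosen at construction time
instead of instantiating EmphasisOriginal on every call, which is what makes
the new emphasis_name option effective. For reference, the "original" mode
scales each token's hidden state by its (word:1.2)-style multiplier and then
restores the global mean; a sketch of that A1111-style behavior:

    import torch

    def emphasis_original(z, multipliers):
        # z: [batch, tokens, dim] transformer output; multipliers: [batch, tokens].
        original_mean = z.mean()
        z = z * multipliers.to(z).unsqueeze(-1).expand(z.shape)
        # Rescaling shifts the overall magnitude; restoring the old mean keeps
        # the conditioning close to the distribution the model was trained on.
        z = z * (original_mean / z.mean())
        return z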