added prompt dropout to happen independently on each TE

Jaret Burkett
2023-11-14 05:26:51 -07:00
parent 7782caa468
commit 4f9cdd916a
7 changed files with 144 additions and 15 deletions


@@ -537,6 +537,7 @@ def encode_prompts_xl(
    use_text_encoder_2: bool = True,  # sdxl
    truncate: bool = True,
    max_length=None,
    dropout_prob=0.0,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    # text_encoder and text_encoder_2's penultimate layer's output
    text_embeds_list = []
@@ -553,6 +554,12 @@ def encode_prompts_xl(
        if idx == 1 and not use_text_encoder_2:
            prompt_list_to_use = ["" for _ in prompts]
        if dropout_prob > 0.0:
            # randomly drop out prompts
            prompt_list_to_use = [
                prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
            ]
        text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
        # set the max length for the next one
        if idx == 0:
@@ -598,9 +605,17 @@ def encode_prompts(
    prompts: list[str],
    truncate: bool = True,
    max_length=None,
    dropout_prob=0.0,
):
    if max_length is None:
        max_length = tokenizer.model_max_length
    if dropout_prob > 0.0:
        # randomly drop out prompts
        prompts = [
            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
        ]
    text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
    text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
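Because the dropout draw in encode_prompts_xl happens inside the per-encoder loop, each text encoder samples its own mask per prompt, so one TE can drop a caption while the other keeps it. A minimal standalone sketch of that behavior; apply_prompt_dropout is a hypothetical helper written for illustration, not a function in this repo:

# Sketch only: per-TE independent prompt dropout, mirroring the diff above.
import torch

def apply_prompt_dropout(prompts: list[str], dropout_prob: float = 0.0) -> list[str]:
    # Replace each prompt with "" independently, with probability dropout_prob.
    if dropout_prob <= 0.0:
        return prompts
    return [
        prompt if torch.rand(1).item() > dropout_prob else ""
        for prompt in prompts
    ]

prompts = ["a photo of a cat", "a watercolor landscape"]
torch.manual_seed(0)
for te_idx in range(2):  # sdxl uses two text encoders
    # Each iteration draws a fresh mask, so TE 0 and TE 1 may drop
    # different prompts from the same batch.
    print(te_idx, apply_prompt_dropout(prompts, dropout_prob=0.5))

Dropping a caption to the empty string during training is the usual setup for classifier-free guidance: at rate dropout_prob the model is trained on the unconditional (empty-prompt) embedding, and sampling each TE's mask independently means the unconditional signal is decorrelated between the two encoders.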