Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Commit: added prompt dropout to happen independently on each TE
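
For context: prompt dropout is the standard trick used when training for classifier-free guidance. With probability dropout_prob a caption is replaced by the empty string, so the model also learns the unconditional case. A minimal sketch of the per-prompt draw, assuming nothing beyond torch (the helper name drop_prompts is illustrative, not from the repo):

    import torch

    def drop_prompts(prompts: list[str], dropout_prob: float) -> list[str]:
        # One independent Bernoulli draw per prompt, mirroring the diff below:
        # keep the caption if rand > dropout_prob, otherwise use the empty string.
        return [p if torch.rand(1).item() > dropout_prob else "" for p in prompts]

    print(drop_prompts(["a photo of a cat", "a photo of a dog"], dropout_prob=0.5))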
@@ -537,6 +537,7 @@ def encode_prompts_xl(
     use_text_encoder_2: bool = True,  # sdxl
     truncate: bool = True,
     max_length=None,
+    dropout_prob=0.0,
 ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
     # text_encoder and text_encoder_2's penultimate layer's output
     text_embeds_list = []
@@ -553,6 +554,12 @@ def encode_prompts_xl(
         if idx == 1 and not use_text_encoder_2:
             prompt_list_to_use = ["" for _ in prompts]

+        if dropout_prob > 0.0:
+            # randomly drop out prompts
+            prompt_list_to_use = [
+                prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
+            ]
+
         text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
         # set the max length for the next one
         if idx == 0:
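
The draw above sits inside encode_prompts_xl's loop over the two (tokenizer, text encoder) pairs, so each TE samples dropout independently, as the commit message says: for a given caption, the first encoder may see "" while the second still sees the full prompt, and both drop together only with probability dropout_prob squared. A hypothetical sanity-check simulation (not repo code):

    import torch

    torch.manual_seed(0)
    dropout_prob, trials = 0.1, 100_000
    both_dropped = 0
    for _ in range(trials):
        te1_dropped = torch.rand(1).item() <= dropout_prob  # draw for text_encoder
        te2_dropped = torch.rand(1).item() <= dropout_prob  # separate draw for text_encoder_2
        both_dropped += te1_dropped and te2_dropped
    print(both_dropped / trials)  # empirically close to dropout_prob**2 = 0.01

Compared with dropping the caption for both encoders at once, this makes the fully unconditional case (both TEs seeing "") rarer, while each individual encoder still sees the empty prompt at the configured rate.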
@@ -598,9 +605,17 @@ def encode_prompts(
     prompts: list[str],
     truncate: bool = True,
     max_length=None,
+    dropout_prob=0.0,
 ):
     if max_length is None:
         max_length = tokenizer.model_max_length

+    if dropout_prob > 0.0:
+        # randomly drop out prompts
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+
     text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
     text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
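
Note the guard: the comprehension only runs when dropout_prob > 0.0, so the default leaves existing callers unchanged. A self-contained sketch of the same guard (apply_prompt_dropout is a hypothetical stand-in for the logic inside encode_prompts):

    import torch

    def apply_prompt_dropout(prompts, dropout_prob=0.0):
        # Same guard as the diff: the comprehension only runs when dropout is
        # enabled, so dropout_prob=0.0 is an exact no-op for existing callers.
        if dropout_prob > 0.0:
            prompts = [p if torch.rand(1).item() > dropout_prob else "" for p in prompts]
        return prompts

    assert apply_prompt_dropout(["a cat", "a dog"]) == ["a cat", "a dog"]  # default: unchanged
    print(apply_prompt_dropout(["a cat", "a dog"], dropout_prob=1.0))      # always dropped: ['', '']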