Added ability to do prompt attn masking for flux

This commit is contained in:
Jaret Burkett
2024-09-02 17:29:36 -06:00
parent d44d4eb61a
commit e5fadddd45
3 changed files with 9 additions and 7 deletions

View File

@@ -517,6 +517,7 @@ def encode_prompts_flux(
truncate: bool = True,
max_length=None,
dropout_prob=0.0,
attn_mask: bool = False,
):
if max_length is None:
max_length = 512
@@ -568,12 +569,9 @@ def encode_prompts_flux(
dtype = text_encoder[1].dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
# prompt_embeds = prompt_embeds * prompt_attention_mask
# _, seq_len, _ = prompt_embeds.shape
# they dont do prompt attention mask?
# prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
if attn_mask:
prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)
return prompt_embeds, pooled_prompt_embeds