mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 01:39:20 +00:00
Added ability to do prompt attn masking for flux
This commit is contained in:
@@ -517,6 +517,7 @@ def encode_prompts_flux(
|
||||
truncate: bool = True,
|
||||
max_length=None,
|
||||
dropout_prob=0.0,
|
||||
attn_mask: bool = False,
|
||||
):
|
||||
if max_length is None:
|
||||
max_length = 512
|
||||
@@ -568,12 +569,9 @@ def encode_prompts_flux(
|
||||
dtype = text_encoder[1].dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
|
||||
# prompt_embeds = prompt_embeds * prompt_attention_mask
|
||||
# _, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# they dont do prompt attention mask?
|
||||
# prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
|
||||
if attn_mask:
|
||||
prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
|
||||
prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
|
||||
Reference in New Issue
Block a user