Added ability to do prompt attn masking for flux

Jaret Burkett
2024-09-02 17:29:36 -06:00
parent d44d4eb61a
commit e5fadddd45
3 changed files with 9 additions and 7 deletions

View File

@@ -417,6 +417,9 @@ class ModelConfig:
         # only for flux for now
         self.quantize = kwargs.get("quantize", False)
         self.low_vram = kwargs.get("low_vram", False)
+        self.attn_masking = kwargs.get("attn_masking", False)
+        if self.attn_masking and not self.is_flux:
+            raise ValueError("attn_masking is only supported with flux models currently")
         pass
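
A minimal usage sketch (not part of the diff) of the new flag, assuming ModelConfig accepts these options as keyword arguments; only quantize, low_vram, and attn_masking appear in the diff above:

# Hypothetical sketch: enabling the new option on a flux model config.
# On a non-flux config, the constructor now raises ValueError.
config = ModelConfig(
    quantize=True,       # existing kwarg, read via kwargs.get
    low_vram=True,       # existing kwarg, read via kwargs.get
    attn_masking=True,   # new kwarg added in this commit
)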

View File

@@ -2001,7 +2001,8 @@ class StableDiffusion:
             prompt,
             truncate=not long_prompts,
             max_length=512,
-            dropout_prob=dropout_prob
+            dropout_prob=dropout_prob,
+            attn_mask=self.model_config.attn_masking
         )
         pe = PromptEmbeds(
             prompt_embeds

View File

@@ -517,6 +517,7 @@ def encode_prompts_flux(
         truncate: bool = True,
         max_length=None,
         dropout_prob=0.0,
+        attn_mask: bool = False,
 ):
     if max_length is None:
         max_length = 512
@@ -568,12 +569,9 @@ def encode_prompts_flux(
     dtype = text_encoder[1].dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-    # prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
-    # prompt_embeds = prompt_embeds * prompt_attention_mask
-    # _, seq_len, _ = prompt_embeds.shape
-    # they dont do prompt attention mask?
-    # prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
+    if attn_mask:
+        prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+        prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)
     return prompt_embeds, pooled_prompt_embeds
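
The net effect of the new branch is to zero out the prompt embeddings at padded token positions. A standalone sketch of that math with dummy tensors; shapes are illustrative and not taken from the diff:

import torch

# Illustrative shapes: batch of 2 prompts, 512-token sequence, 4096-dim embeddings.
batch_size, seq_len, dim = 2, 512, 4096
prompt_embeds = torch.randn(batch_size, seq_len, dim)
attention_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
attention_mask[:, :7] = 1  # pretend each prompt has 7 real tokens, the rest padding

# Same operations as the diff: broadcast the mask over the embedding dim and multiply.
mask = attention_mask.unsqueeze(-1).expand(prompt_embeds.shape)
prompt_embeds = prompt_embeds * mask.to(dtype=prompt_embeds.dtype)

assert torch.all(prompt_embeds[:, 7:] == 0)  # embeddings at padded positions are zeroed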