Mirror of https://github.com/ostris/ai-toolkit.git
Added ability to do prompt attn masking for flux
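
The change adds an opt-in attn_masking flag to ModelConfig (flux models only for now), threads it from StableDiffusion's prompt encoding through to encode_prompts_flux, and there multiplies the prompt embeddings by the tokenizer's attention mask so that the embeddings of padded token positions are zeroed.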
@@ -417,6 +417,9 @@ class ModelConfig:
         # only for flux for now
         self.quantize = kwargs.get("quantize", False)
         self.low_vram = kwargs.get("low_vram", False)
+        self.attn_masking = kwargs.get("attn_masking", False)
+        if self.attn_masking and not self.is_flux:
+            raise ValueError("attn_masking is only supported with flux models currently")
         pass
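
For context, here is a minimal, runnable sketch of the kwargs pattern this hunk extends. The class is hypothetical and trimmed to the relevant fields; only the three added lines mirror the diff.

    # Sketch only: ModelConfigSketch is not the real class, just the pattern.
    class ModelConfigSketch:
        def __init__(self, **kwargs):
            self.is_flux = kwargs.get("is_flux", False)
            self.quantize = kwargs.get("quantize", False)
            self.low_vram = kwargs.get("low_vram", False)
            # the new option, defaulting off
            self.attn_masking = kwargs.get("attn_masking", False)
            if self.attn_masking and not self.is_flux:
                raise ValueError("attn_masking is only supported with flux models currently")

    ModelConfigSketch(is_flux=True, attn_masking=True)    # accepted
    # ModelConfigSketch(attn_masking=True)                # raises ValueError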
@@ -2001,7 +2001,8 @@ class StableDiffusion:
             prompt,
             truncate=not long_prompts,
             max_length=512,
-            dropout_prob=dropout_prob
+            dropout_prob=dropout_prob,
+            attn_mask=self.model_config.attn_masking
         )
         pe = PromptEmbeds(
             prompt_embeds

@@ -517,6 +517,7 @@ def encode_prompts_flux(
     truncate: bool = True,
     max_length=None,
     dropout_prob=0.0,
+    attn_mask: bool = False,
 ):
     if max_length is None:
         max_length = 512
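
The previous two hunks thread the new flag from the model config down to the flux prompt encoder; because attn_mask defaults to False, existing callers keep the old behavior. A simplified, hypothetical view of the plumbing (the real signatures carry more parameters, elided here):

    # Trimmed sketch of the call chain; only the parameters visible in the
    # diff are real, everything else is illustrative.
    def encode_prompts_flux(prompt, truncate=True, max_length=None,
                            dropout_prob=0.0, attn_mask=False):
        ...

    class StableDiffusionSketch:
        def __init__(self, model_config):
            self.model_config = model_config

        def encode_prompt(self, prompt, long_prompts=False, dropout_prob=0.0):
            return encode_prompts_flux(
                prompt,
                truncate=not long_prompts,
                max_length=512,
                dropout_prob=dropout_prob,
                attn_mask=self.model_config.attn_masking,  # new in this commit
            )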
@@ -568,12 +569,9 @@ def encode_prompts_flux(
     dtype = text_encoder[1].dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
-    # prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
-    # prompt_embeds = prompt_embeds * prompt_attention_mask
-    # _, seq_len, _ = prompt_embeds.shape
-
-    # they dont do prompt attention mask?
-    # prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
+    if attn_mask:
+        prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+        prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)
 
     return prompt_embeds, pooled_prompt_embeds
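
The masking step itself is an elementwise multiply that zeroes the embedding rows of padded tokens. A self-contained PyTorch sketch with dummy tensors (shapes and the tokenizer-style attention_mask are illustrative):

    import torch

    batch_size, seq_len, hidden = 2, 8, 16
    prompt_embeds = torch.randn(batch_size, seq_len, hidden)

    # Tokenizers return attention_mask of shape (batch, seq_len): 1 for real
    # tokens, 0 for padding. Here the second prompt pads positions 5..7.
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    attention_mask[1, 5:] = 0

    # unsqueeze(-1) -> (batch, seq_len, 1); expand -> (batch, seq_len, hidden),
    # so the multiply zeroes every hidden dim at padded positions.
    mask = attention_mask.unsqueeze(-1).expand(prompt_embeds.shape)
    prompt_embeds = prompt_embeds * mask.to(dtype=prompt_embeds.dtype)

    assert prompt_embeds[1, 5:].abs().sum() == 0  # padded rows zeroed

Note that this masks at the embedding level (zeroing padded token embeddings before they reach the transformer) rather than passing a mask into the attention layers themselves; the removed comment "they dont do prompt attention mask?" suggests this is the gap the commit addresses.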