diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 2b844324..deb4d68b 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -417,6 +417,9 @@ class ModelConfig:
         # only for flux for now
         self.quantize = kwargs.get("quantize", False)
         self.low_vram = kwargs.get("low_vram", False)
+        self.attn_masking = kwargs.get("attn_masking", False)
+        if self.attn_masking and not self.is_flux:
+            raise ValueError("attn_masking is only supported with flux models currently")
         pass
 
 
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index db1f8e81..92ea2de7 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -2001,7 +2001,8 @@ class StableDiffusion:
                     prompt,
                     truncate=not long_prompts,
                     max_length=512,
-                    dropout_prob=dropout_prob
+                    dropout_prob=dropout_prob,
+                    attn_mask=self.model_config.attn_masking
                 )
                 pe = PromptEmbeds(
                     prompt_embeds
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 75e5b923..592f7cc6 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -517,6 +517,7 @@ def encode_prompts_flux(
         truncate: bool = True,
         max_length=None,
         dropout_prob=0.0,
+        attn_mask: bool = False,
 ):
     if max_length is None:
         max_length = 512
@@ -568,12 +569,9 @@ def encode_prompts_flux(
     dtype = text_encoder[1].dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
-    # prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
-    # prompt_embeds = prompt_embeds * prompt_attention_mask
-    # _, seq_len, _ = prompt_embeds.shape
-
-    # they dont do prompt attention mask?
-    # prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
+    if attn_mask:
+        prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+        prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)
 
     return prompt_embeds, pooled_prompt_embeds
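
For context: the new `attn_masking` flag is read onto `ModelConfig`, passed to `encode_prompts_flux` as `attn_mask`, and, when enabled, the T5 tokenizer's attention mask is broadcast over the embedding dimension and multiplied into the prompt embeddings so that padded token positions become zero vectors. The snippet below is a minimal standalone sketch of that masking step, not code from the diff; the tensor sizes and dummy data are illustrative, and the assumption that the flag lives under the model section of a training config is mine.

```python
import torch

# Illustrative sizes only; the diff encodes T5 prompts with max_length=512.
batch_size, seq_len, hidden_dim = 2, 512, 4096

# Stand-ins for the T5 prompt embeddings and the tokenizer's attention mask
# (1 for real tokens, 0 for padding; second prompt padded after 10 tokens).
prompt_embeds = torch.randn(batch_size, seq_len, hidden_dim)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
attention_mask[1, 10:] = 0

attn_mask = True  # assumed to mirror attn_masking: true in the model config
if attn_mask:
    # Same shape trick as the diff: [B, S] -> [B, S, 1] -> broadcast to [B, S, D],
    # then multiply so padded positions contribute zero vectors downstream.
    mask = attention_mask.unsqueeze(-1).expand(prompt_embeds.shape)
    prompt_embeds = prompt_embeds * mask.to(dtype=prompt_embeds.dtype)

assert torch.all(prompt_embeds[1, 10:] == 0)  # padded positions are zeroed
```

Note that, as in the diff, the embedding values themselves are masked; no separate attention-mask tensor is returned or passed further into the transformer.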