diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 541822bd..8fb9eefb 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -2056,11 +2056,16 @@ class StableDiffusion:
             self.text_encoder,
             prompt,
             truncate=not long_prompts,
-            max_length=77,  # todo set this higher when not transfer learning
+            max_length=256,
             dropout_prob=dropout_prob
         )
+
+        # apply the attention mask: zero out embeddings at padded token positions
+        prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape)
+        embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device)
         return PromptEmbeds(
             embeds,
+            # do we want attn mask here?
             # attention_mask=attention_mask,
         )
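
For reference, here is a minimal, self-contained sketch of the masking step this patch adds, assuming `embeds` of shape `(batch, seq_len, hidden)` from the text encoder and a 0/1 `attention_mask` of shape `(batch, seq_len)`. The concrete shapes and the 10-token prompt length below are illustrative, not taken from the repo:

```python
import torch

# Illustrative shapes mirroring the diff: embeds from the text encoder,
# attention_mask marking real tokens (1) vs. padding (0).
batch, seq_len, hidden = 2, 256, 768
embeds = torch.randn(batch, seq_len, hidden)
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
attention_mask[:, 10:] = 0  # pretend the prompt is 10 tokens long

# Broadcast the (batch, seq_len) mask to (batch, seq_len, hidden) and
# zero out the embedding vectors at padded positions, as the patch does.
prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape)
embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device)

assert torch.all(embeds[:, 10:] == 0)  # padded positions are now zeroed
```

Zeroing the padded positions bakes the mask into the embedding tensor itself, which is presumably why the commented-out `attention_mask=attention_mask` pass-through in `PromptEmbeds` is left as an open question: once the embeddings are masked, downstream consumers may not need the mask separately.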