Apply a mask to the embeds for SD if using T5 encoder

This commit is contained in:
Jaret Burkett
2024-10-04 10:55:20 -06:00
parent a800c9d19e
commit 9452929300

View File

@@ -2056,11 +2056,16 @@ class StableDiffusion:
self.text_encoder,
prompt,
truncate=not long_prompts,
max_length=77, # todo set this higher when not transfer learning
max_length=256,
dropout_prob=dropout_prob
)
# zero out the embeds wherever the attention mask is 0 (mask is broadcast across the last dim)
prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape)
embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device)
return PromptEmbeds(
embeds,
# do we want attn mask here?
# attention_mask=attention_mask,
)