Various bug fixes. Created a contextual alpha mask module to calculate the alpha mask

This commit is contained in:
Jaret Burkett
2024-01-18 16:34:27 -07:00
parent 86c70a2a1f
commit f17ad8d794
7 changed files with 93 additions and 28 deletions


@@ -804,28 +804,27 @@ class StableDiffusion:
             detach_unconditional=False,
             **kwargs,
     ):
-        with torch.no_grad():
-            # get the embeddings
-            if text_embeddings is None and conditional_embeddings is None:
-                raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
-            if text_embeddings is None and unconditional_embeddings is not None:
-                text_embeddings = concat_prompt_embeds([
-                    unconditional_embeddings,  # negative embedding
-                    conditional_embeddings,  # positive embedding
-                ])
-            elif text_embeddings is None and conditional_embeddings is not None:
-                # not doing cfg
-                text_embeddings = conditional_embeddings
+        # get the embeddings
+        if text_embeddings is None and conditional_embeddings is None:
+            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+        if text_embeddings is None and unconditional_embeddings is not None:
+            text_embeddings = concat_prompt_embeds([
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+            ])
+        elif text_embeddings is None and conditional_embeddings is not None:
+            # not doing cfg
+            text_embeddings = conditional_embeddings
 
-            # CFG is comparing neg and positive, if we have concatenated embeddings
-            # then we are doing it, otherwise we are not and takes half the time.
-            do_classifier_free_guidance = True
+        # CFG is comparing neg and positive, if we have concatenated embeddings
+        # then we are doing it, otherwise we are not and takes half the time.
+        do_classifier_free_guidance = True
 
-            # check if batch size of embeddings matches batch size of latents
-            if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
-                do_classifier_free_guidance = False
-            elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
-                raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
+        # check if batch size of embeddings matches batch size of latents
+        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+            do_classifier_free_guidance = False
+        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
         latents = latents.to(self.device_torch)
         text_embeddings = text_embeddings.to(self.device_torch)
         timestep = timestep.to(self.device_torch)
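
Note (not part of the commit): the batch-size check above encodes the usual classifier-free-guidance convention, where the negative and positive prompt embeddings are concatenated along the batch dimension and the latents are duplicated to match before the UNet call. A minimal sketch of that relationship, with illustrative tensor shapes and names (latent_model_input, neg_embeds, pos_embeds are hypothetical and not taken from this repository):

import torch

# Illustrative shapes only; the real code passes prompt-embedding objects around.
latents = torch.randn(2, 4, 64, 64)    # batch of 2 latents
neg_embeds = torch.randn(2, 77, 768)   # unconditional (negative) prompt embeddings
pos_embeds = torch.randn(2, 77, 768)   # conditional (positive) prompt embeddings

# CFG case: embeddings are concatenated, so their batch is 2x the latent batch.
text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0)
do_classifier_free_guidance = text_embeds.shape[0] == latents.shape[0] * 2

if do_classifier_free_guidance:
    # duplicate latents so one UNet pass covers both the negative and positive halves
    latent_model_input = torch.cat([latents] * 2, dim=0)
else:
    # no duplication: the UNet does half the work
    latent_model_input = latents

assert latent_model_input.shape[0] == text_embeds.shape[0]

With only conditional embeddings (text_embeds = pos_embeds), the batch sizes already match, which is the case the diff detects by setting do_classifier_free_guidance = False.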