Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-22 23:39:21 +00:00
Various bug fixes. Created a contextual alpha mask module to calculate the alpha mask.
@@ -804,28 +804,27 @@ class StableDiffusion:
        detach_unconditional=False,
        **kwargs,
    ):
        with torch.no_grad():
            # get the embeddings
            if text_embeddings is None and conditional_embeddings is None:
                raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
            if text_embeddings is None and unconditional_embeddings is not None:
                text_embeddings = concat_prompt_embeds([
                    unconditional_embeddings,  # negative embedding
                    conditional_embeddings,  # positive embedding
                ])
            elif text_embeddings is None and conditional_embeddings is not None:
                # not doing CFG
                text_embeddings = conditional_embeddings

            # CFG compares the negative and positive predictions. If we were handed
            # concatenated embeddings we are doing it; otherwise we are not, which
            # takes half the time.
            do_classifier_free_guidance = True

            # check if the batch size of the embeddings matches the batch size of the latents
            if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
                do_classifier_free_guidance = False
            elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
                raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

            latents = latents.to(self.device_torch)
            text_embeddings = text_embeddings.to(self.device_torch)
            timestep = timestep.to(self.device_torch)
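
The batch-size check above is what lets a caller opt in or out of classifier-free guidance implicitly: embeddings with the same batch size as the latents disable CFG, while a doubled batch (negative embeddings concatenated before positive ones) enables it. A minimal standalone sketch of that dispatch rule, using plain tensors instead of the toolkit's PromptEmbeds/concat_prompt_embeds types, with a hypothetical helper name wants_cfg and illustrative shapes:

    import torch

    def wants_cfg(latents: torch.Tensor, text_embeds: torch.Tensor) -> bool:
        # Mirrors the check in the hunk above: equal batch sizes -> no CFG;
        # doubled embedding batch -> CFG; anything else is a caller error.
        if latents.shape[0] == text_embeds.shape[0]:
            return False
        if latents.shape[0] * 2 == text_embeds.shape[0]:
            return True
        raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

    latents = torch.randn(2, 4, 64, 64)   # batch of 2 latents
    pos = torch.randn(2, 77, 768)         # positive (conditional) embeddings
    neg = torch.randn(2, 77, 768)         # negative (unconditional) embeddings

    print(wants_cfg(latents, pos))                    # False: batches match, no CFG
    print(wants_cfg(latents, torch.cat([neg, pos])))  # True: negative-then-positive doubles the batch

The negative-then-positive ordering matters: downstream CFG code conventionally splits the doubled batch back into (unconditional, conditional) halves in that order before combining the two noise predictions.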