Fixed issue with DOP when using Z-Image

Jaret Burkett
2025-11-28 09:36:21 -07:00
parent 08a39754a4
commit d42f5af2fc
3 changed files with 11 additions and 5 deletions

@@ -803,9 +803,8 @@ class BaseModel:
         # then we are doing it, otherwise we are not and takes half the time.
         do_classifier_free_guidance = True
         # check if batch size of embeddings matches batch size of latents
         if isinstance(text_embeddings.text_embeds, list):
-            if len(text_embeddings.text_embeds) == latents.shape[0]:
             if len(text_embeddings.text_embeds[0].shape) == 2:
                 # handle list of embeddings
                 te_batch_size = len(text_embeddings.text_embeds)
             else:

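For context, a minimal self-contained sketch of the batch-size check this hunk feeds into (DOP here is presumably the trainer's differential output preservation pass). The tensor shapes, the example values, and the final comparison against latents.shape[0] are illustrative assumptions rather than code from the repository; the hunk itself only shows how te_batch_size is derived when text_embeds arrives as a list of per-sample 2D tensors, as it appears to for Z-Image.

import torch

# Illustrative shapes only; not taken from the repository.
latents = torch.randn(2, 16, 64, 64)   # a batch of two noisy latents

# Z-Image-style embeddings: one 2D (seq_len, hidden_dim) tensor per prompt,
# so the batch size is the length of the list, not a tensor dimension.
text_embeds = [torch.randn(77, 2048), torch.randn(64, 2048)]

do_classifier_free_guidance = True
if isinstance(text_embeds, list):
    if len(text_embeds[0].shape) == 2:
        # list of per-sample 2D tensors -> the batch size is the list length
        te_batch_size = len(text_embeds)
    else:
        # list of batched tensors -> read the batch dimension of the first one
        te_batch_size = text_embeds[0].shape[0]
else:
    te_batch_size = text_embeds.shape[0]

# Assumption about the surrounding code (not shown in the hunk): if the
# embedding batch already matches the latent batch, no unconditional
# embeddings were concatenated, so classifier-free guidance is skipped.
if te_batch_size == latents.shape[0]:
    do_classifier_free_guidance = False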

@@ -87,7 +87,10 @@ class PromptEmbeds:
     def expand_to_batch(self, batch_size):
         pe = self.clone()
         if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
-            current_batch_size = pe.text_embeds[0].shape[0]
+            if len(pe.text_embeds[0].shape) == 2:
+                current_batch_size = len(pe.text_embeds)
+            else:
+                current_batch_size = pe.text_embeds[0].shape[0]
         else:
             current_batch_size = pe.text_embeds.shape[0]
         if current_batch_size == batch_size:
@@ -95,7 +98,11 @@ class PromptEmbeds:
         if current_batch_size != 1:
             raise Exception("Can only expand batch size for batch size 1")
         if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
-            pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds]
+            if len(pe.text_embeds[0].shape) == 2:
+                # batch is a list of tensors
+                pe.text_embeds = pe.text_embeds * batch_size
+            else:
+                pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds]
         else:
             pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
         if pe.pooled_embeds is not None:
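To see the effect of the expand_to_batch change end to end, here is a rough, hypothetical stand-in for PromptEmbeds reduced to its text_embeds field (the real class also carries pooled embeddings and other state). It combines the batch-size detection from the first hunk of this file with the expansion logic from the second; it is a sketch under those assumptions, not the repository's implementation.

import torch

# Hypothetical, stripped-down stand-in for PromptEmbeds; illustration only.
class MiniPromptEmbeds:
    def __init__(self, text_embeds):
        self.text_embeds = text_embeds

    def expand_to_batch(self, batch_size):
        te = self.text_embeds
        # work out the current batch size, mirroring the first hunk above
        if isinstance(te, (list, tuple)):
            if len(te[0].shape) == 2:
                current = len(te)        # list of per-sample (seq_len, dim) tensors
            else:
                current = te[0].shape[0]
        else:
            current = te.shape[0]
        if current == batch_size:
            return self
        if current != 1:
            raise Exception("Can only expand batch size for batch size 1")
        # expand, mirroring the second hunk above
        if isinstance(te, (list, tuple)):
            if len(te[0].shape) == 2:
                # the batch is a list of tensors: repeat the single entry
                te = list(te) * batch_size
            else:
                te = [t.expand(batch_size, -1) for t in te]
        else:
            te = te.expand(batch_size, -1)
        return MiniPromptEmbeds(te)

# Z-Image-style single prompt: a one-element list holding a (seq_len, dim) tensor.
pe = MiniPromptEmbeds([torch.randn(77, 2048)])
out = pe.expand_to_batch(4)
print(len(out.text_embeds), out.text_embeds[0].shape)   # 4 torch.Size([77, 2048])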