From d42f5af2fcf8fbb803f66886b5a13349f4fd563a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 28 Nov 2025 09:36:21 -0700 Subject: [PATCH] Fixed issue with DOP when using Z-Image --- toolkit/models/base_model.py | 3 +-- toolkit/prompt_utils.py | 11 +++++++++-- version.py | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index dd290d5f..9671a8c4 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -803,9 +803,8 @@ class BaseModel: # then we are doing it, otherwise we are not and takes half the time. do_classifier_free_guidance = True - # check if batch size of embeddings matches batch size of latents if isinstance(text_embeddings.text_embeds, list): - if len(text_embeddings.text_embeds) == latents.shape[0]: + if len(text_embeddings.text_embeds[0].shape) == 2: # handle list of embeddings te_batch_size = len(text_embeddings.text_embeds) else: diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 3a968322..0bcbe876 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -87,7 +87,10 @@ class PromptEmbeds: def expand_to_batch(self, batch_size): pe = self.clone() if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple): - current_batch_size = pe.text_embeds[0].shape[0] + if len(pe.text_embeds[0].shape) == 2: + current_batch_size = len(pe.text_embeds) + else: + current_batch_size = pe.text_embeds[0].shape[0] else: current_batch_size = pe.text_embeds.shape[0] if current_batch_size == batch_size: @@ -95,7 +98,11 @@ class PromptEmbeds: if current_batch_size != 1: raise Exception("Can only expand batch size for batch size 1") if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple): - pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds] + if len(pe.text_embeds[0].shape) == 2: + # batch is a list of tensors + pe.text_embeds = pe.text_embeds * batch_size + else: + pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds] else: pe.text_embeds = pe.text_embeds.expand(batch_size, -1) if pe.pooled_embeds is not None: diff --git a/version.py b/version.py index 5a921ca6..13588ece 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.7.6" \ No newline at end of file +VERSION = "0.7.7" \ No newline at end of file