Fixed an issue in z-image that prevented training at larger batch sizes

This commit is contained in:
Jaret Burkett
2026-03-10 15:43:25 -06:00
parent 4909b809c7
commit 35b1cde3cb

View File

@@ -247,25 +247,9 @@ class EncodedPromptPair:
def concat_prompt_embeds(prompt_embeds: list["PromptEmbeds"], padding_side: str = "right") -> PromptEmbeds:
# --- pad text_embeds ---
if isinstance(prompt_embeds[0].text_embeds, (list, tuple)):
embed_list = []
for i in range(len(prompt_embeds[0].text_embeds)):
max_len = max(p.text_embeds[i].shape[1] for p in prompt_embeds)
padded = []
for p in prompt_embeds:
t = p.text_embeds[i]
if t.shape[1] < max_len:
pad = torch.zeros(
(t.shape[0], max_len - t.shape[1], *t.shape[2:]),
dtype=t.dtype,
device=t.device,
)
if padding_side == "right":
t = torch.cat([t, pad], dim=1)
else:
t = torch.cat([pad, t], dim=1)
padded.append(t)
embed_list.append(torch.cat(padded, dim=0))
text_embeds = embed_list
text_embeds = []
for p in prompt_embeds:
text_embeds += p.text_embeds
else:
max_len = max(p.text_embeds.shape[1] for p in prompt_embeds)
padded = []