mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-12 14:09:49 +00:00
Fixed issue on z-image that prevented training at a larger batch size
This commit is contained in:
@@ -247,25 +247,9 @@ class EncodedPromptPair:
|
||||
def concat_prompt_embeds(prompt_embeds: list["PromptEmbeds"], padding_side: str = "right") -> PromptEmbeds:
|
||||
# --- pad text_embeds ---
|
||||
if isinstance(prompt_embeds[0].text_embeds, (list, tuple)):
|
||||
embed_list = []
|
||||
for i in range(len(prompt_embeds[0].text_embeds)):
|
||||
max_len = max(p.text_embeds[i].shape[1] for p in prompt_embeds)
|
||||
padded = []
|
||||
for p in prompt_embeds:
|
||||
t = p.text_embeds[i]
|
||||
if t.shape[1] < max_len:
|
||||
pad = torch.zeros(
|
||||
(t.shape[0], max_len - t.shape[1], *t.shape[2:]),
|
||||
dtype=t.dtype,
|
||||
device=t.device,
|
||||
)
|
||||
if padding_side == "right":
|
||||
t = torch.cat([t, pad], dim=1)
|
||||
else:
|
||||
t = torch.cat([pad, t], dim=1)
|
||||
padded.append(t)
|
||||
embed_list.append(torch.cat(padded, dim=0))
|
||||
text_embeds = embed_list
|
||||
text_embeds = []
|
||||
for p in prompt_embeds:
|
||||
text_embeds += p.text_embeds
|
||||
else:
|
||||
max_len = max(p.text_embeds.shape[1] for p in prompt_embeds)
|
||||
padded = []
|
||||
|
||||
Reference in New Issue
Block a user