Fixed issue with training Qwen with cached text embeds at a batch size greater than 1

Jaret Burkett
2025-08-28 08:07:12 -06:00
parent fc5b41666a
commit 9ef425a1c5


@@ -210,20 +210,63 @@ class EncodedPromptPair:
         return self


-def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
-    if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple):
+def concat_prompt_embeds(prompt_embeds: list["PromptEmbeds"]):
+    # --- pad text_embeds ---
+    if isinstance(prompt_embeds[0].text_embeds, (list, tuple)):
         embed_list = []
         for i in range(len(prompt_embeds[0].text_embeds)):
-            embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0))
+            max_len = max(p.text_embeds[i].shape[1] for p in prompt_embeds)
+            padded = []
+            for p in prompt_embeds:
+                t = p.text_embeds[i]
+                if t.shape[1] < max_len:
+                    pad = torch.zeros(
+                        (t.shape[0], max_len - t.shape[1], *t.shape[2:]),
+                        dtype=t.dtype,
+                        device=t.device,
+                    )
+                    t = torch.cat([t, pad], dim=1)
+                padded.append(t)
+            embed_list.append(torch.cat(padded, dim=0))
         text_embeds = embed_list
     else:
-        text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
+        max_len = max(p.text_embeds.shape[1] for p in prompt_embeds)
+        padded = []
+        for p in prompt_embeds:
+            t = p.text_embeds
+            if t.shape[1] < max_len:
+                pad = torch.zeros(
+                    (t.shape[0], max_len - t.shape[1], *t.shape[2:]),
+                    dtype=t.dtype,
+                    device=t.device,
+                )
+                t = torch.cat([t, pad], dim=1)
+            padded.append(t)
+        text_embeds = torch.cat(padded, dim=0)
+    # --- pooled embeds ---
     pooled_embeds = None
     if prompt_embeds[0].pooled_embeds is not None:
         pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
+    # --- attention mask ---
     attention_mask = None
     if prompt_embeds[0].attention_mask is not None:
-        attention_mask = torch.cat([p.attention_mask for p in prompt_embeds], dim=0)
+        max_len = max(p.attention_mask.shape[1] for p in prompt_embeds)
+        padded = []
+        for p in prompt_embeds:
+            m = p.attention_mask
+            if m.shape[1] < max_len:
+                pad = torch.zeros(
+                    (m.shape[0], max_len - m.shape[1]),
+                    dtype=m.dtype,
+                    device=m.device,
+                )
+                m = torch.cat([m, pad], dim=1)
+            padded.append(m)
+        attention_mask = torch.cat(padded, dim=0)
+    # wrap back into PromptEmbeds
     pe = PromptEmbeds([text_embeds, pooled_embeds])
     pe.attention_mask = attention_mask
     return pe
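
Why the padding is needed, as a minimal sketch (the tensor shapes and the hidden size 3584 below are illustrative, not taken from the repo): cached Qwen text embeds are stored per prompt at that prompt's own token count, so the old concat along dim=0 only worked when every prompt in the batch happened to tokenize to the same length. Zero-padding the shorter tensors along the sequence dimension (dim=1) up to the batch max makes the dim=0 concat valid.

import torch

# Two cached prompts with different token counts: (batch, seq_len, hidden)
short = torch.randn(1, 12, 3584)
long = torch.randn(1, 20, 3584)

# Pre-fix behavior: this raises a RuntimeError because the seq_len
# dimension differs (12 vs 20) -- the batch-size > 1 failure in the title.
# torch.cat([short, long], dim=0)

# The fix: zero-pad the shorter tensor along dim=1 up to the batch max,
# mirroring the padding logic in concat_prompt_embeds above.
max_len = max(short.shape[1], long.shape[1])
pad = torch.zeros(
    (short.shape[0], max_len - short.shape[1], *short.shape[2:]),
    dtype=short.dtype,
    device=short.device,
)
batched = torch.cat([torch.cat([short, pad], dim=1), long], dim=0)
print(batched.shape)  # torch.Size([2, 20, 3584])

The zero-filled embed positions stay inert because the attention mask is padded with zeros in the same pass, so the model never attends to the filler tokens.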