Fixed an issue when training Qwen with cached text embeds and a batch size greater than 1
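The root cause: concat_prompt_embeds joined cached embeds with a plain torch.cat, which requires every dimension except the batch dimension to match, so it failed whenever two prompts in a batch tokenized to different lengths. The diff below right-pads text embeds and attention masks to the longest sequence in the batch before concatenating; a short illustrative sketch follows the diff.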
@@ -210,20 +210,63 @@ class EncodedPromptPair:
         return self
 
 
-def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
-    if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple):
+def concat_prompt_embeds(prompt_embeds: list["PromptEmbeds"]):
+    # --- pad text_embeds ---
+    if isinstance(prompt_embeds[0].text_embeds, (list, tuple)):
         embed_list = []
         for i in range(len(prompt_embeds[0].text_embeds)):
-            embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0))
+            max_len = max(p.text_embeds[i].shape[1] for p in prompt_embeds)
+            padded = []
+            for p in prompt_embeds:
+                t = p.text_embeds[i]
+                if t.shape[1] < max_len:
+                    pad = torch.zeros(
+                        (t.shape[0], max_len - t.shape[1], *t.shape[2:]),
+                        dtype=t.dtype,
+                        device=t.device,
+                    )
+                    t = torch.cat([t, pad], dim=1)
+                padded.append(t)
+            embed_list.append(torch.cat(padded, dim=0))
         text_embeds = embed_list
     else:
-        text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
+        max_len = max(p.text_embeds.shape[1] for p in prompt_embeds)
+        padded = []
+        for p in prompt_embeds:
+            t = p.text_embeds
+            if t.shape[1] < max_len:
+                pad = torch.zeros(
+                    (t.shape[0], max_len - t.shape[1], *t.shape[2:]),
+                    dtype=t.dtype,
+                    device=t.device,
+                )
+                t = torch.cat([t, pad], dim=1)
+            padded.append(t)
+        text_embeds = torch.cat(padded, dim=0)
+
+    # --- pooled embeds ---
     pooled_embeds = None
     if prompt_embeds[0].pooled_embeds is not None:
         pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
+
+    # --- attention mask ---
     attention_mask = None
     if prompt_embeds[0].attention_mask is not None:
-        attention_mask = torch.cat([p.attention_mask for p in prompt_embeds], dim=0)
+        max_len = max(p.attention_mask.shape[1] for p in prompt_embeds)
+        padded = []
+        for p in prompt_embeds:
+            m = p.attention_mask
+            if m.shape[1] < max_len:
+                pad = torch.zeros(
+                    (m.shape[0], max_len - m.shape[1]),
+                    dtype=m.dtype,
+                    device=m.device,
+                )
+                m = torch.cat([m, pad], dim=1)
+            padded.append(m)
+        attention_mask = torch.cat(padded, dim=0)
+
+    # wrap back into PromptEmbeds
     pe = PromptEmbeds([text_embeds, pooled_embeds])
     pe.attention_mask = attention_mask
     return pe
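A minimal self-contained sketch of the padding idea, assuming only stock PyTorch: the pad_and_cat helper and the tensor shapes (e.g. the 3584 hidden size) are illustrative, not part of ai-toolkit.

import torch


def pad_and_cat(tensors: list[torch.Tensor], pad_dim: int = 1) -> torch.Tensor:
    # Right-pad every tensor along `pad_dim` (the sequence dimension) to the
    # longest length in the batch, then concatenate along dim 0 (the batch).
    max_len = max(t.shape[pad_dim] for t in tensors)
    padded = []
    for t in tensors:
        if t.shape[pad_dim] < max_len:
            pad_shape = list(t.shape)
            pad_shape[pad_dim] = max_len - t.shape[pad_dim]
            pad = torch.zeros(pad_shape, dtype=t.dtype, device=t.device)
            t = torch.cat([t, pad], dim=pad_dim)
        padded.append(t)
    return torch.cat(padded, dim=0)


# Two cached embeds with the same hidden size but different token counts.
a = torch.randn(1, 12, 3584)  # 12-token prompt
b = torch.randn(1, 20, 3584)  # 20-token prompt

try:
    torch.cat([a, b], dim=0)  # the old behavior: fails on mismatched lengths
except RuntimeError as e:
    print("plain cat fails:", e)

batch = pad_and_cat([a, b])  # the padded behavior
print(batch.shape)           # torch.Size([2, 20, 3584])

Because the attention mask is zero-padded in the same way in the diff above, downstream attention can ignore the padded positions, so the zero-padding should not change what the model attends to.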