diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py
index b0558b8c..b7e27cf9 100644
--- a/toolkit/prompt_utils.py
+++ b/toolkit/prompt_utils.py
@@ -210,20 +210,63 @@ class EncodedPromptPair:
         return self
 
 
-def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
-    if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple):
+def concat_prompt_embeds(prompt_embeds: list["PromptEmbeds"]):
+    # --- pad text_embeds ---
+    if isinstance(prompt_embeds[0].text_embeds, (list, tuple)):
         embed_list = []
         for i in range(len(prompt_embeds[0].text_embeds)):
-            embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0))
+            max_len = max(p.text_embeds[i].shape[1] for p in prompt_embeds)
+            padded = []
+            for p in prompt_embeds:
+                t = p.text_embeds[i]
+                if t.shape[1] < max_len:
+                    pad = torch.zeros(
+                        (t.shape[0], max_len - t.shape[1], *t.shape[2:]),
+                        dtype=t.dtype,
+                        device=t.device,
+                    )
+                    t = torch.cat([t, pad], dim=1)
+                padded.append(t)
+            embed_list.append(torch.cat(padded, dim=0))
         text_embeds = embed_list
     else:
-        text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
+        max_len = max(p.text_embeds.shape[1] for p in prompt_embeds)
+        padded = []
+        for p in prompt_embeds:
+            t = p.text_embeds
+            if t.shape[1] < max_len:
+                pad = torch.zeros(
+                    (t.shape[0], max_len - t.shape[1], *t.shape[2:]),
+                    dtype=t.dtype,
+                    device=t.device,
+                )
+                t = torch.cat([t, pad], dim=1)
+            padded.append(t)
+        text_embeds = torch.cat(padded, dim=0)
+
+    # --- pooled embeds ---
     pooled_embeds = None
     if prompt_embeds[0].pooled_embeds is not None:
         pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
+
+    # --- attention mask ---
     attention_mask = None
     if prompt_embeds[0].attention_mask is not None:
-        attention_mask = torch.cat([p.attention_mask for p in prompt_embeds], dim=0)
+        max_len = max(p.attention_mask.shape[1] for p in prompt_embeds)
+        padded = []
+        for p in prompt_embeds:
+            m = p.attention_mask
+            if m.shape[1] < max_len:
+                pad = torch.zeros(
+                    (m.shape[0], max_len - m.shape[1]),
+                    dtype=m.dtype,
+                    device=m.device,
+                )
+                m = torch.cat([m, pad], dim=1)
+            padded.append(m)
+        attention_mask = torch.cat(padded, dim=0)
+
+    # wrap back into PromptEmbeds
     pe = PromptEmbeds([text_embeds, pooled_embeds])
     pe.attention_mask = attention_mask
     return pe
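
For reference, the padding strategy this diff applies to text_embeds and attention_mask reduces to the standalone sketch below: zero-pad each tensor along the sequence dimension (dim=1) up to the batch maximum, then concatenate along the batch dimension (dim=0). The helper name pad_and_concat and the example shapes are illustrative only, not part of the toolkit.

import torch

def pad_and_concat(tensors: list[torch.Tensor]) -> torch.Tensor:
    # Zero-pad along the sequence dimension (dim=1) to the longest tensor,
    # then concatenate along the batch dimension (dim=0) -- the same logic
    # concat_prompt_embeds now applies per tensor.
    max_len = max(t.shape[1] for t in tensors)
    padded = []
    for t in tensors:
        if t.shape[1] < max_len:
            pad = torch.zeros(
                (t.shape[0], max_len - t.shape[1], *t.shape[2:]),
                dtype=t.dtype,
                device=t.device,
            )
            t = torch.cat([t, pad], dim=1)
        padded.append(t)
    return torch.cat(padded, dim=0)

# Two prompt embeddings with different token counts (hypothetical shapes).
a = torch.randn(1, 77, 4096)
b = torch.randn(1, 120, 4096)
print(pad_and_concat([a, b]).shape)  # torch.Size([2, 120, 4096])

Before this change, torch.cat would raise a size-mismatch error whenever the prompts in a batch encoded to different sequence lengths; padding with zeros (and extending the attention mask with zeros so the padded positions stay masked) makes the concatenation well defined.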