diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 0f3c10a9..ef3a7da8 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -247,25 +247,9 @@ class EncodedPromptPair: def concat_prompt_embeds(prompt_embeds: list["PromptEmbeds"], padding_side: str = "right") -> PromptEmbeds: # --- pad text_embeds --- if isinstance(prompt_embeds[0].text_embeds, (list, tuple)): - embed_list = [] - for i in range(len(prompt_embeds[0].text_embeds)): - max_len = max(p.text_embeds[i].shape[1] for p in prompt_embeds) - padded = [] - for p in prompt_embeds: - t = p.text_embeds[i] - if t.shape[1] < max_len: - pad = torch.zeros( - (t.shape[0], max_len - t.shape[1], *t.shape[2:]), - dtype=t.dtype, - device=t.device, - ) - if padding_side == "right": - t = torch.cat([t, pad], dim=1) - else: - t = torch.cat([pad, t], dim=1) - padded.append(t) - embed_list.append(torch.cat(padded, dim=0)) - text_embeds = embed_list + text_embeds = [] + for p in prompt_embeds: + text_embeds += p.text_embeds else: max_len = max(p.text_embeds.shape[1] for p in prompt_embeds) padded = []