diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index aaf0c94a..3a968322 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -142,8 +142,10 @@ class PromptEmbeds: text_embeds = [] pooled_embeds = None attention_mask = [] + is_list = False for key in sorted(state_dict.keys()): if key.startswith("text_embed_"): + is_list = True text_embeds.append(state_dict[key]) elif key == "text_embed": text_embeds.append(state_dict[key]) @@ -155,7 +157,7 @@ class PromptEmbeds: attention_mask.append(state_dict[key]) pe = cls(None) pe.text_embeds = text_embeds - if len(text_embeds) == 1: + if len(text_embeds) == 1 and not is_list: pe.text_embeds = text_embeds[0] if pooled_embeds is not None: pe.pooled_embeds = pooled_embeds