WIP Flex 2 pipeline

This commit is contained in:
Jaret Burkett
2025-02-16 14:54:29 -07:00
parent 87e557cf1e
commit 1f7784510d
2 changed files with 161 additions and 1 deletions

View File

@@ -151,7 +151,7 @@ class LLMAdapter(torch.nn.Module):
prompt_embeds = text_encoder(
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
)
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds = prompt_embeds.hidden_states[-1]
prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]