From 2be692639864d4238a8e3c2483998f7b1e036ad4 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 14 Feb 2025 07:23:37 -0700
Subject: [PATCH] Added back system prompt for LLM and removed those tokens
 from the embeddings

---
 toolkit/models/llm_adapter.py | 27 +++++++++++++++++++++++----
 1 file changed, 23 insertions(+), 4 deletions(-)

diff --git a/toolkit/models/llm_adapter.py b/toolkit/models/llm_adapter.py
index 8a77eefe..51880ffa 100644
--- a/toolkit/models/llm_adapter.py
+++ b/toolkit/models/llm_adapter.py
@@ -44,9 +44,25 @@ class LLMAdapter(torch.nn.Module):
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.llm_ref: weakref.ref = weakref.ref(llm)
         self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
+        # make sure we can pad
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
 
-        self.system_prompt = ""
-        # self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. "
+        # self.system_prompt = ""
+        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. "
+
+        # determine length of system prompt
+        sys_prompt_tokenized = tokenizer(
+            [self.system_prompt],
+            padding="longest",
+            return_tensors="pt",
+        )
+
+        sys_prompt_tokenized_ids = sys_prompt_tokenized.input_ids[0]
+
+        self.system_prompt_length = sys_prompt_tokenized_ids.shape[0]
+
+        print(f"System prompt length: {self.system_prompt_length}")
 
         self.hidden_size = llm.config.hidden_size
 
@@ -81,12 +97,15 @@ class LLMAdapter(torch.nn.Module):
         text_inputs = tokenizer(
             prompt,
             padding="max_length",
-            max_length=max_sequence_length,
+            max_length=max_sequence_length + self.system_prompt_length,
             truncation=True,
             return_tensors="pt",
         )
 
         text_input_ids = text_inputs.input_ids.to(device)
+
+        # remove the system prompt from the input
+        text_input_ids = text_input_ids[:, self.system_prompt_length:]
         prompt_attention_mask = text_inputs.attention_mask.to(device)
 
         prompt_embeds = text_encoder(
@@ -110,8 +129,8 @@
 
         prompt = prompt if isinstance(prompt, list) else [prompt]
 
-        # prompt = [self.system_prompt + " " + p for p in prompt]
         prompt = [self.system_prompt + p for p in prompt]
+        # prompt = [self.system_prompt + p for p in prompt]
 
         prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
             prompt=prompt,
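
In short, the __init__ change does two things: it guarantees the tokenizer can
pad (falling back to the EOS token) and it tokenizes the system prompt once to
record how many positions it occupies. A minimal standalone sketch of that
pattern, assuming a Hugging Face tokenizer; "gpt2" is only an illustrative
stand-in for whatever LLM the adapter actually wraps via weakref:

    # Sketch of the new __init__ logic; "gpt2" is a hypothetical stand-in.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Fall back to EOS so padding="max_length" cannot fail on tokenizers
    # that ship without a pad token (mirrors the patch).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    system_prompt = (
        "You are an assistant designed to generate superior images with the "
        "superior degree of image-text alignment based on textual prompts "
        "or user prompts. "
    )

    # Tokenize the system prompt once to learn how many positions it takes;
    # padding="longest" is a no-op for a single sequence, kept for fidelity.
    sys_prompt_tokenized = tokenizer(
        [system_prompt],
        padding="longest",
        return_tensors="pt",
    )
    system_prompt_length = sys_prompt_tokenized.input_ids[0].shape[0]
    print(f"System prompt length: {system_prompt_length}")

Caching the length up front lets the encode path budget for the system prompt
on every call without re-tokenizing it.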
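In _get_prompt_embeds, max_length then grows by that cached length, so the
caller prepending the system prompt does not eat into the user prompt's token
budget, and the same number of positions is sliced back off. Note the diff
applies the slice to text_input_ids right after tokenization; the sketch below
instead drops those positions from the encoder's hidden states and attention
mask, which is one reading that matches the commit message's stated goal of
removing the tokens from the embeddings. Self-contained and hedged, with
"gpt2" and a 77-token cap as assumed stand-ins:

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Illustrative stand-ins; the adapter holds its real llm/tokenizer
    # through weakrefs.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    text_encoder = AutoModel.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    system_prompt = "You are an assistant designed to generate superior images. "
    system_prompt_length = tokenizer(
        [system_prompt], return_tensors="pt"
    ).input_ids.shape[1]

    max_sequence_length = 77  # hypothetical cap for the user prompt alone
    prompt = [system_prompt + p for p in ["a photo of a cat"]]

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        # Budget extra positions so truncation still leaves
        # max_sequence_length tokens for the user prompt.
        max_length=max_sequence_length + system_prompt_length,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        hidden = text_encoder(
            text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
        ).last_hidden_state

    # Drop the system-prompt positions: the causal LLM already attended to
    # the instruction while encoding the user tokens, but the conditioning
    # tensor handed downstream keeps its usual max_sequence_length width.
    prompt_embeds = hidden[:, system_prompt_length:, :]
    prompt_attention_mask = text_inputs.attention_mask[:, system_prompt_length:]
    assert prompt_embeds.shape[1] == max_sequence_length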