Added back syustem prompt for llm and remove those tokens from the embeddings

This commit is contained in:
Jaret Burkett
2025-02-14 07:23:37 -07:00
parent 87ac031859
commit 2be6926398

View File

@@ -44,9 +44,25 @@ class LLMAdapter(torch.nn.Module):
self.sd_ref: weakref.ref = weakref.ref(sd)
self.llm_ref: weakref.ref = weakref.ref(llm)
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
# make sure we can pad
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
self.system_prompt = ""
# self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
# self.system_prompt = ""
self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
# determine length of system prompt
sys_prompt_tokenized = tokenizer(
[self.system_prompt],
padding="longest",
return_tensors="pt",
)
sys_prompt_tokenized_ids = sys_prompt_tokenized.input_ids[0]
self.system_prompt_length = sys_prompt_tokenized_ids.shape[0]
print(f"System prompt length: {self.system_prompt_length}")
self.hidden_size = llm.config.hidden_size
@@ -81,12 +97,15 @@ class LLMAdapter(torch.nn.Module):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
max_length=max_sequence_length + self.system_prompt_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
# remove the system prompt from the input
text_input_ids = text_input_ids[:, self.system_prompt_length:]
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = text_encoder(
@@ -110,8 +129,8 @@ class LLMAdapter(torch.nn.Module):
prompt = prompt if isinstance(prompt, list) else [prompt]
# prompt = [self.system_prompt + " <Prompt Start> " + p for p in prompt]
prompt = [self.system_prompt + p for p in prompt]
# prompt = [self.system_prompt + p for p in prompt]
prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
prompt=prompt,