mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added back system prompt for llm and removed those tokens from the embeddings
```diff
@@ -44,9 +44,25 @@ class LLMAdapter(torch.nn.Module):
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.llm_ref: weakref.ref = weakref.ref(llm)
         self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
         # make sure we can pad
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        self.system_prompt = ""
-        # self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
+        # self.system_prompt = ""
+        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
+
+        # determine length of system prompt
+        sys_prompt_tokenized = tokenizer(
+            [self.system_prompt],
+            padding="longest",
+            return_tensors="pt",
+        )
+
+        sys_prompt_tokenized_ids = sys_prompt_tokenized.input_ids[0]
+
+        self.system_prompt_length = sys_prompt_tokenized_ids.shape[0]
+
+        print(f"System prompt length: {self.system_prompt_length}")
 
         self.hidden_size = llm.config.hidden_size
```
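For context, a minimal runnable sketch of what the new `__init__` lines measure, using a stand-in Hugging Face checkpoint (the model actually wired into `LLMAdapter` is not shown in this diff). With a one-string batch, `padding="longest"` pads nothing, so the row length is exactly the token count of the system prompt, including any special tokens the tokenizer prepends:

```python
# Sketch only: "Qwen/Qwen2-0.5B-Instruct" is a placeholder checkpoint, not
# necessarily the model this adapter wraps.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

system_prompt = (
    "You are an assistant designed to generate superior images with the "
    "superior degree of image-text alignment based on textual prompts or "
    "user prompts. <Prompt Start> "
)

# padding="longest" over a one-element batch adds no padding, so this is the
# exact token count of the system prompt (special tokens included, if any).
sys_prompt_tokenized = tokenizer([system_prompt], padding="longest", return_tensors="pt")
system_prompt_length = sys_prompt_tokenized.input_ids[0].shape[0]
print(f"System prompt length: {system_prompt_length}")
```

Note the implicit assumption: the first `system_prompt_length` tokens of the concatenated `system_prompt + p` must coincide with the tokens of `system_prompt` alone. BPE tokenizers generally honor that here because the prompt ends at a space boundary, but it is worth verifying for a given vocabulary.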
```diff
@@ -81,12 +97,15 @@ class LLMAdapter(torch.nn.Module):
         text_inputs = tokenizer(
             prompt,
             padding="max_length",
-            max_length=max_sequence_length,
+            max_length=max_sequence_length + self.system_prompt_length,
             truncation=True,
             return_tensors="pt",
         )
 
         text_input_ids = text_inputs.input_ids.to(device)
+
+        # remove the system prompt from the input
+        text_input_ids = text_input_ids[:, self.system_prompt_length:]
 
         prompt_attention_mask = text_inputs.attention_mask.to(device)
         prompt_embeds = text_encoder(
```
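The arithmetic in this hunk: the tokenizer budget grows by `self.system_prompt_length`, so after the leading system-prompt tokens are sliced away, each row still holds exactly `max_sequence_length` ids for the user text. A small continuation of the sketch above (same assumed `tokenizer` and `system_prompt`; the `256` is an illustrative value):

```python
max_sequence_length = 256  # illustrative; the real value comes from the caller

# Tokenize the combined prompt with an enlarged budget...
text_inputs = tokenizer(
    [system_prompt + "a cat riding a bicycle"],
    padding="max_length",
    max_length=max_sequence_length + system_prompt_length,
    truncation=True,
    return_tensors="pt",
)

text_input_ids = text_inputs.input_ids
# ...then drop the leading system-prompt columns, leaving exactly
# max_sequence_length ids of user text (plus right padding) per row.
text_input_ids = text_input_ids[:, system_prompt_length:]
assert text_input_ids.shape[1] == max_sequence_length
```

One thing the diff leaves untouched is `prompt_attention_mask`, which is still taken from the full-width `text_inputs`; if the mask is sliced to match the ids, that happens outside the lines shown here.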
```diff
@@ -110,8 +129,8 @@ class LLMAdapter(torch.nn.Module):
 
         prompt = prompt if isinstance(prompt, list) else [prompt]
 
-        # prompt = [self.system_prompt + " <Prompt Start> " + p for p in prompt]
+        prompt = [self.system_prompt + p for p in prompt]
         # prompt = [self.system_prompt + p for p in prompt]
 
         prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
             prompt=prompt,
```
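The caller-side change: the system prompt is now prepended by plain concatenation, and the old commented-out variant that spliced in `" <Prompt Start> "` by hand is retired, since the delimiter now lives at the end of `self.system_prompt` itself. Assuming that reading of the commented line, the two spellings build the same string:

```python
# Illustrative reconstruction: `base` stands for the instruction text without
# the delimiter, which is how the retired commented-out line appears to have
# used self.system_prompt.
base = (
    "You are an assistant designed to generate superior images with the "
    "superior degree of image-text alignment based on textual prompts or "
    "user prompts."
)
system_prompt = base + " <Prompt Start> "  # as now set in __init__
p = "a cat riding a bicycle"

old_style = base + " <Prompt Start> " + p  # the retired commented variant
new_style = system_prompt + p              # the line this hunk activates
assert old_style == new_style
```

The trailing space in the delimiter also keeps the user text starting at a clean token boundary, which is part of what makes the fixed-length slicing in `_get_prompt_embeds` safe.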