mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added back system prompt for LLM and removed those tokens from the embeddings
This commit is contained in:
@@ -44,9 +44,25 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||||
self.llm_ref: weakref.ref = weakref.ref(llm)
|
self.llm_ref: weakref.ref = weakref.ref(llm)
|
||||||
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
||||||
|
# make sure we can pad
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
self.system_prompt = ""
|
# self.system_prompt = ""
|
||||||
# self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
|
self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
|
||||||
|
|
||||||
|
# determine length of system prompt
|
||||||
|
sys_prompt_tokenized = tokenizer(
|
||||||
|
[self.system_prompt],
|
||||||
|
padding="longest",
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
sys_prompt_tokenized_ids = sys_prompt_tokenized.input_ids[0]
|
||||||
|
|
||||||
|
self.system_prompt_length = sys_prompt_tokenized_ids.shape[0]
|
||||||
|
|
||||||
|
print(f"System prompt length: {self.system_prompt_length}")
|
||||||
|
|
||||||
self.hidden_size = llm.config.hidden_size
|
self.hidden_size = llm.config.hidden_size
|
||||||
|
|
||||||
@@ -81,12 +97,15 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
text_inputs = tokenizer(
|
text_inputs = tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
max_length=max_sequence_length,
|
max_length=max_sequence_length + self.system_prompt_length,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
text_input_ids = text_inputs.input_ids.to(device)
|
text_input_ids = text_inputs.input_ids.to(device)
|
||||||
|
|
||||||
|
# remove the system prompt from the input
|
||||||
|
text_input_ids = text_input_ids[:, self.system_prompt_length:]
|
||||||
|
|
||||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||||
prompt_embeds = text_encoder(
|
prompt_embeds = text_encoder(
|
||||||
@@ -110,8 +129,8 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
prompt = prompt if isinstance(prompt, list) else [prompt]
|
prompt = prompt if isinstance(prompt, list) else [prompt]
|
||||||
|
|
||||||
# prompt = [self.system_prompt + " <Prompt Start> " + p for p in prompt]
|
|
||||||
prompt = [self.system_prompt + p for p in prompt]
|
prompt = [self.system_prompt + p for p in prompt]
|
||||||
|
# prompt = [self.system_prompt + p for p in prompt]
|
||||||
|
|
||||||
prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
|
prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|||||||
Reference in New Issue
Block a user