mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed issue with token replacements
This commit is contained in:
@@ -155,11 +155,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
|
||||
if self.embedding is not None:
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
prompt, add_if_not_present=False
|
||||
)
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt, self.trigger_word
|
||||
prompt, self.trigger_word, add_if_not_present=False
|
||||
)
|
||||
|
||||
gen_img_config_list.append(GenerateImageConfig(
|
||||
@@ -363,15 +363,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
expand_token=True,
|
||||
add_if_not_present=True,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
|
||||
# make sure trigger is in the prompts if not a regularization run
|
||||
if self.trigger_word is not None and not is_reg:
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt,
|
||||
trigger=self.trigger_word,
|
||||
add_if_not_present=True,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
conditioned_prompts.append(prompt)
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ class Embedding:
|
||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||
|
||||
# see how many times replace_with is in the prompt
|
||||
num_instances = prompt.count(replace_with)
|
||||
num_instances = output_prompt.count(replace_with)
|
||||
|
||||
if num_instances == 0 and add_if_not_present:
|
||||
# add it to the beginning of the prompt
|
||||
|
||||
@@ -614,6 +614,7 @@ class StableDiffusion:
|
||||
return latents
|
||||
|
||||
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
|
||||
# sd1.5 embeddings are (bs, 77, 768)
|
||||
prompt = prompt
|
||||
# if it is not a list, make it one
|
||||
if not isinstance(prompt, list):
|
||||
|
||||
Reference in New Issue
Block a user