diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index d96f9b46..c6b516de 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -155,11 +155,11 @@ class BaseSDTrainProcess(BaseTrainProcess): # ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here if self.embedding is not None: prompt = self.embedding.inject_embedding_to_prompt( - prompt, + prompt, add_if_not_present=False ) if self.trigger_word is not None: prompt = self.sd.inject_trigger_into_prompt( - prompt, self.trigger_word + prompt, self.trigger_word, add_if_not_present=False ) gen_img_config_list.append(GenerateImageConfig( @@ -363,15 +363,15 @@ class BaseSDTrainProcess(BaseTrainProcess): prompt = self.embedding.inject_embedding_to_prompt( prompt, expand_token=True, - add_if_not_present=True, + add_if_not_present=not is_reg, ) # make sure trigger is in the prompts if not a regularization run - if self.trigger_word is not None and not is_reg: + if self.trigger_word is not None: prompt = self.sd.inject_trigger_into_prompt( prompt, trigger=self.trigger_word, - add_if_not_present=True, + add_if_not_present=not is_reg, ) conditioned_prompts.append(prompt) diff --git a/toolkit/embedding.py b/toolkit/embedding.py index 3bc0a68e..b326dde3 100644 --- a/toolkit/embedding.py +++ b/toolkit/embedding.py @@ -159,7 +159,7 @@ class Embedding: output_prompt = output_prompt.replace(to_replace, replace_with) # see how many times replace_with is in the prompt - num_instances = prompt.count(replace_with) + num_instances = output_prompt.count(replace_with) if num_instances == 0 and add_if_not_present: # add it to the beginning of the prompt diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 34ccb396..5e773acc 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -614,6 +614,7 @@ class StableDiffusion: return latents def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds: + # sd1.5 embeddings are (bs, 77, 768) prompt = prompt # if it is not a list, make it one if not isinstance(prompt, list):